From fff10809bfd78b93743eec09e420b6c9ef2e3780 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 15 Jun 2025 02:48:00 -0700 Subject: [PATCH] Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210) --- .../layers/attention/flashinfer_backend.py | 3 +- .../attention/flashinfer_mla_backend.py | 4 +- .../srt/layers/attention/triton_backend.py | 3 +- python/sglang/srt/mem_cache/memory_pool.py | 61 --- python/sglang/srt/speculative/eagle_utils.py | 486 ++++-------------- python/sglang/srt/speculative/eagle_worker.py | 174 ++----- test/srt/test_eagle_infer_b.py | 66 --- 7 files changed, 150 insertions(+), 647 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f11de5641..876141083 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1049,13 +1049,14 @@ class FlashInferMultiStepDraftBackend: kv_indices_buffer, self.kv_indptr, forward_batch.positions, + num_seqs, + self.topk, self.pool_len, kv_indices_buffer.shape[1], self.kv_indptr.shape[1], next_power_of_2(num_seqs), next_power_of_2(self.speculative_num_steps), next_power_of_2(bs), - self.page_size, ) assert forward_batch.spec_info is not None diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 19fa09818..c4192a715 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -789,7 +789,6 @@ class FlashInferMLAMultiStepDraftBackend: # Cached variables for generate_draft_decode_kv_indices self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] - self.page_size = model_runner.server_args.page_size def common_template( self, @@ -810,13 +809,14 @@ class FlashInferMLAMultiStepDraftBackend: kv_indices_buffer, self.kv_indptr, forward_batch.positions, + num_seqs, + self.topk, self.pool_len, kv_indices_buffer.shape[1], self.kv_indptr.shape[1], next_power_of_2(num_seqs), next_power_of_2(self.speculative_num_steps), next_power_of_2(bs), - self.page_size, ) assert forward_batch.spec_info is not None diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 970ceb999..0a3aef9c3 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -784,13 +784,14 @@ class TritonMultiStepDraftBackend: kv_indices_buffer, self.kv_indptr, forward_batch.positions, + num_seqs, + self.topk, self.pool_len, kv_indices_buffer.shape[1], self.kv_indptr.shape[1], next_power_of_2(num_seqs), next_power_of_2(self.speculative_num_steps), next_power_of_2(bs), - self.page_size, ) for i in range(self.speculative_num_steps): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 1e823be10..bac310da3 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -294,19 +294,6 @@ class MHATokenToKVPool(KVCache): for _ in range(self.layer_num) ] - self.data_ptrs = torch.tensor( - [x.data_ptr() for x in self.k_buffer + self.v_buffer], - dtype=torch.uint64, - device=self.device, - ) - self.data_strides = torch.tensor( - [ - np.prod(x.shape[1:]) * x.dtype.itemsize - for x in self.k_buffer + self.v_buffer - ], - device=self.device, - ) - def _clear_buffers(self): del self.k_buffer del self.v_buffer @@ -464,16 +451,6 @@ class MHATokenToKVPool(KVCache): self.k_buffer[layer_id - self.start_layer][loc] = cache_k self.v_buffer[layer_id - self.start_layer][loc] = cache_v - def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): - copy_all_layer_kv_cache[(len(self.data_ptrs),)]( - self.data_ptrs, - self.data_strides, - tgt_loc, - src_loc, - len(tgt_loc), - next_power_of_2(len(tgt_loc)), - ) - @triton.jit def set_mla_kv_buffer_kernel( @@ -764,41 +741,3 @@ class DoubleSparseTokenToKVPool(KVCache): def transfer_per_layer(self, indices, flat_data, layer_id): pass - - -@triton.jit -def copy_all_layer_kv_cache( - data_ptrs, - strides, - tgt_loc_ptr, - src_loc_ptr, - num_locs, - num_locs_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 128 - - bid = tl.program_id(0) - stride = tl.load(strides + bid) - - data_ptr = tl.load(data_ptrs + bid) - data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8)) - - num_locs_offset = tl.arange(0, num_locs_upper) - tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs) - src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs) - - # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks - # because this copy is an inplace operation. - - num_loop = tl.cdiv(stride, BLOCK_SIZE) - for i in range(num_loop): - copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :] - value = tl.load( - data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask - ) - tl.store( - data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :], - value, - mask=mask, - ) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index cedc2ee88..8bb1222da 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -67,6 +67,8 @@ class EagleDraftInput: kv_indptr: torch.Tensor = None kv_indices: torch.Tensor = None + all_padding_lens: Optional[torch.Tensor] = None + def prepare_for_extend(self, batch: ScheduleBatch): # Prefill only generate 1 token. assert len(self.verified_id) == len(batch.seq_lens) @@ -91,7 +93,6 @@ class EagleDraftInput: batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend batch.return_logprob = False - batch.return_hidden_states = False self.capture_hidden_mode = CaptureHiddenMode.LAST self.accept_length.add_(1) @@ -115,8 +116,10 @@ class EagleDraftInput: req_to_token: torch.Tensor, ): bs = self.accept_length.numel() + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) @@ -136,6 +139,7 @@ class EagleDraftInput: kv_indices, req_to_token.size(1), ) + return kv_indices, cum_kv_seq_len, qo_indptr, None def filter_batch(self, new_indices: torch.Tensor): @@ -266,7 +270,7 @@ class EagleVerifyInput: logits_output: torch.Tensor, token_to_kv_pool_allocator: TokenToKVPoolAllocator, page_size: int, - vocab_mask: Optional[torch.Tensor] = None, # For grammar + vocab_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Verify and find accepted tokens based on logits output and batch @@ -290,14 +294,6 @@ class EagleVerifyInput: ) accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") - # Apply the custom logit processors if registered in the sampling info. - if sampling_info.has_custom_logit_processor: - apply_custom_logit_processor( - logits_output.next_token_logits, - sampling_info, - num_tokens_in_batch=self.draft_token_num, - ) - # Apply penalty if sampling_info.penalizer_orchestrator.is_required: # This is a relaxed version of penalties for speculative decoding. @@ -359,13 +355,7 @@ class EagleVerifyInput: draft_probs = torch.zeros( target_probs.shape, dtype=torch.float32, device="cuda" ) - - # coins for rejection sampling coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda") - # coins for final sampling - coins_for_final_sampling = torch.rand( - (bs,), dtype=torch.float32, device="cuda" - ) tree_speculative_sampling_target_only( predicts=predict, # mutable accept_index=accept_index, # mutable @@ -375,7 +365,6 @@ class EagleVerifyInput: retrive_next_token=self.retrive_next_token.to(torch.int32), retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), uniform_samples=coins, - # uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, threshold_single=global_server_args_dict[ @@ -398,8 +387,8 @@ class EagleVerifyInput: spec_steps=self.spec_steps, ) + new_accept_index = [] unfinished_index = [] - unfinished_accept_index = [] accept_index_cpu = accept_index.tolist() predict_cpu = predict.tolist() has_finished = False @@ -407,10 +396,12 @@ class EagleVerifyInput: # Iterate every accepted token and check if req has finished after append the token # should be checked BEFORE free kv cache slots for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): + new_accept_index_ = [] for j, idx in enumerate(accept_index_row): if idx == -1: break id = predict_cpu[idx] + # if not found_finished: req.output_ids.append(id) req.check_finished() if req.finished(): @@ -419,6 +410,8 @@ class EagleVerifyInput: accept_index[i, j + 1 :] = -1 break else: + new_accept_index_.append(idx) + # update grammar state if req.grammar is not None: try: req.grammar.accept_token(id) @@ -428,104 +421,50 @@ class EagleVerifyInput: ) raise e if not req.finished(): + new_accept_index.extend(new_accept_index_) unfinished_index.append(i) - if idx == -1: - unfinished_accept_index.append(accept_index[i, :j]) - else: - unfinished_accept_index.append(accept_index[i]) req.spec_verify_ct += 1 if has_finished: accept_length = (accept_index != -1).sum(dim=1) - 1 # Free the KV cache for unaccepted tokens - # TODO: fuse them accept_index = accept_index[accept_index != -1] verified_id = predict[accept_index] evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False - if page_size == 1: - # TODO: boolean array index leads to a device sync. Remove it. - token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) - else: - if self.topk == 1: - # Only evict full empty page. Do not evict partial empty page - align_evict_mask_to_page_size[len(batch.seq_lens),]( - batch.seq_lens, - evict_mask, - page_size, - self.draft_token_num, - next_power_of_2(self.draft_token_num), - ) - token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) - else: - # Shift the accepted tokens to the beginning. - # Only evict the last part - src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc( - batch.seq_lens, - batch.out_cache_loc, - accept_index, - accept_length, - self.draft_token_num, - page_size, - ) - to_free_slots = torch.empty( - (to_free_num_slots.sum().item(),), - dtype=torch.int64, - device=to_free_num_slots.device, - ) + if page_size != 1: + align_evict_mask_to_page_size[len(batch.seq_lens),]( + batch.seq_lens, + evict_mask, + page_size, + self.draft_token_num, + next_power_of_2(self.draft_token_num), + ) - # out_cache_loc: [0 1 2, 3 4 5, 6 7 8] - # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1] - # tgt_cache_loc: [0 1 , 3 4 , 6 ] - # to_free_slots: [ 2, 5, 7 8] - # to_free_slots also needs to be page-aligned without the first partial page - # - # split each row of out_cache_loc into two parts. - # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1 - # 2. the second part goes to to_free_slots. - get_target_cache_loc[(bs,)]( - tgt_cache_loc, - to_free_slots, - accept_length, - to_free_num_slots, - batch.out_cache_loc, - self.draft_token_num, - next_power_of_2(self.draft_token_num), - next_power_of_2(bs), - ) - - # Free the kv cache - token_to_kv_pool_allocator.free(to_free_slots) - - # Copy the kv cache - batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache( - tgt_cache_loc, src_cache_loc - ) + token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) # Construct EagleVerifyOutput if not has_finished: - if page_size == 1 or self.topk == 1: - batch.out_cache_loc = batch.out_cache_loc[accept_index] - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens + accept_length + 1, - batch.out_cache_loc, - batch.req_to_token_pool.req_to_token.shape[1], - next_power_of_2(bs), - ) - else: - batch.out_cache_loc = tgt_cache_loc + batch.out_cache_loc = batch.out_cache_loc[accept_index] + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) batch.seq_lens.add_(accept_length + 1) + accept_length_cpu = accept_length.tolist() draft_input = EagleDraftInput() draft_input.hidden_states = batch.spec_info.hidden_states[accept_index] draft_input.verified_id = verified_id draft_input.accept_length = accept_length - draft_input.accept_length_cpu = accept_length.tolist() + draft_input.accept_length_cpu = accept_length_cpu draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices @@ -533,66 +472,47 @@ class EagleVerifyInput: draft_input=draft_input, logits_output=logits_output, verified_id=verified_id, - accept_length_per_req_cpu=draft_input.accept_length_cpu, + accept_length_per_req_cpu=accept_length_cpu, accepted_indices=accept_index, ) else: - if page_size == 1 or self.topk == 1: - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens + accept_length + 1, - batch.out_cache_loc[accept_index], - batch.req_to_token_pool.req_to_token.shape[1], - next_power_of_2(bs), - ) - batch.seq_lens.add_(accept_length + 1) - + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) accept_length_cpu = accept_length.tolist() + draft_input = EagleDraftInput() - if len(unfinished_accept_index) > 0: - unfinished_accept_index = torch.cat(unfinished_accept_index) - unfinished_index_device = torch.tensor( - unfinished_index, dtype=torch.int64, device=predict.device - ) - draft_input_accept_length_cpu = [ + if len(new_accept_index) > 0: + new_accept_index = torch.tensor(new_accept_index, device="cuda") + unfinished_index_device = torch.tensor(unfinished_index, device="cuda") + draft_input.hidden_states = batch.spec_info.hidden_states[ + new_accept_index + ] + draft_input.verified_id = predict[new_accept_index] + draft_input.accept_length_cpu = [ accept_length_cpu[i] for i in unfinished_index ] - if page_size == 1 or self.topk == 1: - batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index] - else: - batch.out_cache_loc = torch.empty( - len(unfinished_index) + sum(draft_input_accept_length_cpu), - dtype=torch.int64, - device=predict.device, - ) - accept_length_filter = create_accept_length_filter( - accept_length, - unfinished_index_device, - batch.seq_lens, - ) - filter_finished_cache_loc_kernel[(bs,)]( - batch.out_cache_loc, - tgt_cache_loc, - accept_length, - accept_length_filter, - next_power_of_2(bs), - next_power_of_2(self.draft_token_num), - ) - - draft_input.hidden_states = batch.spec_info.hidden_states[ - unfinished_accept_index - ] - draft_input.verified_id = predict[unfinished_accept_index] - draft_input.accept_length_cpu = draft_input_accept_length_cpu draft_input.accept_length = accept_length[unfinished_index_device] - draft_input.seq_lens_for_draft_extend = batch.seq_lens[ - unfinished_index_device - ] - draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ - unfinished_index_device - ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[ + unfinished_index_device + ] + draft_input.req_pool_indices_for_draft_extend = ( + batch.req_pool_indices[unfinished_index_device] + ) + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens + draft_input.req_pool_indices_for_draft_extend = ( + batch.req_pool_indices + ) + batch.out_cache_loc = batch.out_cache_loc[new_accept_index] return EagleVerifyOutput( draft_input=draft_input, @@ -669,75 +589,36 @@ def assign_draft_cache_locs( req_pool_indices, req_to_token, seq_lens, - extend_lens, - num_new_pages_per_topk, out_cache_loc, pool_len: tl.constexpr, topk: tl.constexpr, speculative_num_steps: tl.constexpr, page_size: tl.constexpr, - bs_upper: tl.constexpr, - iter_upper: tl.constexpr, ): - BLOCK_SIZE: tl.constexpr = 128 + BLOCK_SIZE: tl.constexpr = 32 pid = tl.program_id(axis=0) + kv_start = tl.load(seq_lens + pid) if page_size == 1 or topk == 1: - copy_len = topk * speculative_num_steps + kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps else: - bs_offset = tl.arange(0, bs_upper) - copy_len = tl.load(extend_lens + pid) - cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid)) - out_cache_ptr = out_cache_loc + cum_copy_len + prefix_len = tl.load(seq_lens + pid) + last_page_len = prefix_len % page_size + num_new_page = ( + last_page_len + speculative_num_steps + page_size - 1 + ) // page_size + kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk) - # Part 1: Copy from out_cache_loc to req_to_token - kv_start = tl.load(seq_lens + pid) token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len - num_loop = tl.cdiv(copy_len, BLOCK_SIZE) + + num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) for i in range(num_loop): - copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = copy_offset < copy_len - data = tl.load(out_cache_ptr + copy_offset, mask=mask) - tl.store(token_pool + kv_start + copy_offset, data, mask=mask) - - if page_size == 1 or topk == 1: - return - - # Part 2: Copy the indices for the last partial page - prefix_len = tl.load(seq_lens + pid) - last_page_len = prefix_len % page_size - offsets = tl.arange(0, page_size) - mask = offsets < last_page_len - num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid) - prefix_base = token_pool + prefix_len - last_page_len - - for topk_id in range(topk): - value = tl.load(prefix_base + offsets, mask=mask) - tl.store( - prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets, - value, - mask=mask, - ) - - # Part 3: Remove the padding in out_cache_loc - iter_offest = tl.arange(0, iter_upper) - for topk_id in range(topk): - indices = tl.load( - prefix_base - + topk_id * num_new_pages_per_topk_ * page_size - + last_page_len - + iter_offest, - mask=iter_offest < speculative_num_steps, - ) - tl.store( - out_cache_loc - + pid * topk * speculative_num_steps - + topk_id * speculative_num_steps - + iter_offest, - indices, - mask=iter_offest < speculative_num_steps, - ) + save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) @triton.jit @@ -748,23 +629,20 @@ def generate_draft_decode_kv_indices( kv_indices, kv_indptr, positions, + num_seqs: tl.constexpr, + topk: tl.constexpr, pool_len: tl.constexpr, kv_indices_stride: tl.constexpr, kv_indptr_stride: tl.constexpr, bs_upper: tl.constexpr, iter_upper: tl.constexpr, num_tokens_upper: tl.constexpr, - page_size: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 128 iters = tl.program_id(axis=0) bid = tl.program_id(axis=1) topk_id = tl.program_id(axis=2) - num_steps = tl.num_programs(axis=0) - num_seqs = tl.num_programs(axis=1) - topk = tl.num_programs(axis=2) - kv_indices += kv_indices_stride * iters kv_indptr += kv_indptr_stride * iters iters += 1 @@ -774,7 +652,6 @@ def generate_draft_decode_kv_indices( seq_len = tl.load(paged_kernel_lens + bid) cum_seq_len = tl.sum(seq_lens) - # Update kv_indices kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) kv_ptr = kv_indices + kv_offset token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len @@ -788,26 +665,10 @@ def generate_draft_decode_kv_indices( kv_offset += BLOCK_SIZE extend_offset = tl.arange(0, iter_upper) - if page_size == 1 or topk == 1: - extend_data = tl.load( - token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper), - mask=extend_offset < iters, - ) - else: - prefix_len = seq_len - last_page_len = prefix_len % page_size - num_new_pages_per_topk = ( - last_page_len + num_steps + page_size - 1 - ) // page_size - prefix_base = seq_len // page_size * page_size - start = ( - prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len - ) - extend_data = tl.load( - token_pool_ptr + start + extend_offset, - mask=extend_offset < iters, - ) - + extend_data = tl.load( + token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, + mask=extend_offset < iters, + ) tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) # Update kv_indptr @@ -846,116 +707,6 @@ def align_evict_mask_to_page_size( tl.store(evict_mask + bid * num_draft_tokens + i, False) -@triton.jit -def get_target_cache_loc( - tgt_cache_loc, - to_free_slots, - accept_length, - to_free_num_slots, - out_cache_loc, - num_verify_tokens: tl.constexpr, - num_verify_tokens_upper: tl.constexpr, - bs_upper: tl.constexpr, -): - bid = tl.program_id(axis=0) - offset = tl.arange(0, num_verify_tokens_upper) - bs_offset = tl.arange(0, bs_upper) - - # write the first part to tgt_cache_loc - accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) - tgt_cache_loc_start = tl.sum(accept_len_all) + bid - copy_len = tl.load(accept_length + bid) + 1 - out_cache_loc_row = tl.load( - out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len - ) - tl.store( - tgt_cache_loc + tgt_cache_loc_start + offset, - out_cache_loc_row, - mask=offset < copy_len, - ) - - # write the second part to to_free_num_pages - to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid) - to_free_num_slots_cur = tl.load(to_free_num_slots + bid) - out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur - to_free_slots_start = tl.sum(to_free_num_slots_all) - - copy_len = to_free_num_slots_cur - out_cache_loc_row = tl.load( - out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset, - mask=offset < copy_len, - ) - tl.store( - to_free_slots + to_free_slots_start + offset, - out_cache_loc_row, - mask=offset < copy_len, - ) - - -@torch.compile(dynamic=True) -def get_src_tgt_cache_loc( - seq_lens: torch.Tensor, - out_cache_loc: torch.Tensor, - accept_index: torch.Tensor, - accept_length: torch.Tensor, - draft_token_num: int, - page_size: int, -): - src_cache_loc = out_cache_loc[accept_index] - tgt_cache_loc = torch.empty_like(src_cache_loc) - extended_len = seq_lens + draft_token_num - keep_len = torch.minimum( - (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size, - extended_len, - ) - to_free_num_slots = extended_len - keep_len - return src_cache_loc, tgt_cache_loc, to_free_num_slots - - -@triton.jit -def filter_finished_cache_loc_kernel( - out_cache_loc, - tgt_cache_loc, - accept_length, - accept_length_filter, - bs_upper: tl.constexpr, - num_verify_tokens_upper: tl.constexpr, -): - bid = tl.program_id(0) - bs_offset = tl.arange(0, bs_upper) - - accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) - old_start = tl.sum(accept_length_all) + bid - - accept_length_filter_all = tl.load( - accept_length_filter + bs_offset, mask=bs_offset < bid - ) - new_start = tl.sum(accept_length_filter_all) - - copy_len = tl.load(accept_length_filter + bid) - copy_offset = tl.arange(0, num_verify_tokens_upper) - value = tl.load( - tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len - ) - tl.store( - out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len - ) - - -@torch.compile(dynamic=True) -def create_accept_length_filter( - accept_length: torch.Tensor, - unfinished_index_device: torch.Tensor, - seq_lens: torch.Tensor, -): - accept_length_filter = torch.zeros_like(accept_length) - accept_length_filter[unfinished_index_device] = ( - accept_length[unfinished_index_device] + 1 - ) - seq_lens.add_(accept_length + 1) - return accept_length_filter - - @torch.compile(dynamic=True) def select_top_k_tokens( i: int, @@ -1005,16 +756,6 @@ def select_top_k_tokens( return input_ids, hidden_states, scores, tree_info -def fast_topk_torch(values, topk, dim): - if topk == 1: - # Use max along the specified dimension to get both value and index - max_value, max_index = torch.max(values, dim=dim) - return max_value.unsqueeze(1), max_index.unsqueeze(1) - else: - # Use topk for efficiency with larger k values - return torch.topk(values, topk, dim=dim) - - def _generate_simulated_accept_index( accept_index, predict, @@ -1024,35 +765,15 @@ def _generate_simulated_accept_index( spec_steps, ): simulate_acc_len_float = float(simulate_acc_len) - if SIMULATE_ACC_METHOD == "multinomial": - simulated_values = torch.normal( - mean=simulate_acc_len_float, - std=1.0, - size=(1,), - device="cpu", - ) - # clamp simulated values to be between 1 and self.spec_steps - simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1) - simulate_acc_len = int(simulated_values.round().item()) - elif SIMULATE_ACC_METHOD == "match-expected": - # multinomial sampling does not match the expected length - # we keep it for the sake of compatibility of existing tests - # but it's better to use "match-expected" for the cases that need to - # match the expected length, One caveat is that this will only sample - # either round down or round up of the expected length - simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float)) - lower = int(simulate_acc_len_float // 1) - upper = lower + 1 if lower < spec_steps + 1 else lower - if lower == upper: - simulate_acc_len = lower - else: - weight_upper = simulate_acc_len_float - lower - weight_lower = 1.0 - weight_upper - probs = torch.tensor([weight_lower, weight_upper], device="cpu") - sampled_index = torch.multinomial(probs, num_samples=1) - simulate_acc_len = lower if sampled_index == 0 else upper - else: - raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}") + simulated_values = torch.normal( + mean=simulate_acc_len_float, + std=1.0, + size=(1,), + device="cpu", + ) + # clamp simulated values to be between 1 and self.spec_steps + simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps) + simulate_acc_len = int(simulated_values.round().item()) accept_indx_first_col = accept_index[:, 0].view(-1, 1) sim_accept_index = torch.full( @@ -1143,9 +864,9 @@ def generate_token_bitmask( """ Generate the logit mask for structured output. Draft model's token can be either valid or invalid with respect to the grammar. - We need to perform DFS to - 1. figure out which tokens are accepted by the grammar. - 2. if so, what is the corresponding logit mask. + We need to perform DFS to figure out: + 1. which tokens are accepted by the grammar + 2. what is the corresponding logit mask. """ num_draft_tokens = draft_tokens_cpu.shape[-1] @@ -1162,7 +883,6 @@ def generate_token_bitmask( device="cpu", ) grammar = req.grammar - s = time.perf_counter() traverse_tree( retrieve_next_token_cpu[i], retrieve_next_sibling_cpu[i], @@ -1172,12 +892,6 @@ def generate_token_bitmask( i * num_draft_tokens : (i + 1) * num_draft_tokens ], ) - tree_traverse_time = time.perf_counter() - s - if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD: - logger.warning( - f"Bit mask generation took {tree_traverse_time} seconds with " - f"grammar: {req.grammar}" - ) verify_input.grammar = grammar return allocate_token_bitmask diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 747bd8e98..0597ad4e0 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -35,17 +35,11 @@ from sglang.srt.speculative.eagle_utils import ( EagleVerifyInput, EagleVerifyOutput, assign_draft_cache_locs, - fast_topk, generate_token_bitmask, select_top_k_tokens, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.utils import ( - empty_context, - get_available_gpu_memory, - is_cuda, - next_power_of_2, -) +from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda if is_cuda(): from sgl_kernel import segment_packbits @@ -158,12 +152,6 @@ class EAGLEWorker(TpModelWorker): self.init_attention_backend() self.init_cuda_graphs() - # Some dummy tensors - self.num_new_pages_per_topk = torch.empty( - (), dtype=torch.int64, device=self.device - ) - self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device) - def init_attention_backend(self): # Create multi-step attn backends and cuda graph runners if self.server_args.attention_backend == "flashinfer": @@ -266,7 +254,7 @@ class EAGLEWorker(TpModelWorker): self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." + f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." ) # Capture extend @@ -281,7 +269,7 @@ class EAGLEWorker(TpModelWorker): ) after_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." + f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." ) @property @@ -302,6 +290,7 @@ class EAGLEWorker(TpModelWorker): A tuple of the final logit output of the target model, next tokens accepted, the batch id (used for overlap schedule), and number of accepted tokens. """ + if batch.forward_mode.is_decode(): with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) @@ -377,21 +366,14 @@ class EAGLEWorker(TpModelWorker): ) # Allocate cache locations - # Layout of the out_cache_loc - # [ topk 0 ] [ topk 1 ] - # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2] if self.page_size == 1: out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots( - num_seqs * self.speculative_num_steps * self.topk, backup_state=True + num_seqs * self.topk * self.speculative_num_steps, backup_state=True ) else: if self.topk == 1: - prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1( - batch.req_to_token_pool.req_to_token, - batch.req_pool_indices, - batch.seq_lens, - self.speculative_num_steps, - ) + prefix_lens = batch.seq_lens + seq_lens = prefix_lens + self.speculative_num_steps extend_num_tokens = num_seqs * self.speculative_num_steps else: # In this case, the last partial page needs to be duplicated. @@ -404,33 +386,29 @@ class EAGLEWorker(TpModelWorker): # "x" means speculative draft tokens # "." means padded tokens - # TODO(lmzheng): The current implementation is still a fake support - # for page size > 1. In the `assign_draft_cache_locs` below, - # we directly move the indices instead of the real kv cache. - # This only works when the kernel backend runs with page size = 1. - # If the kernel backend runs with page size > 1, we need to - # duplicate the real KV cache. The overhead of duplicating KV - # cache seems okay because the draft KV cache only has one layer. - # see a related copy operation in MHATokenToKVPool::move_kv_cache. - - ( - prefix_lens, - seq_lens, - last_loc, - self.num_new_pages_per_topk, - self.extend_lens, - ) = get_last_loc_large_page_size_large_top_k( - batch.req_to_token_pool.req_to_token, - batch.req_pool_indices, - batch.seq_lens, - self.speculative_num_steps, - self.topk, - self.page_size, + # TODO: fuse these ops + prefix_lens = batch.seq_lens + last_page_lens = prefix_lens % self.page_size + num_new_pages = ( + last_page_lens + self.speculative_num_steps + self.page_size - 1 + ) // self.page_size + seq_lens = ( + prefix_lens // self.page_size * self.page_size + + num_new_pages * (self.page_size * self.topk) ) + extend_num_tokens = torch.sum(seq_lens - prefix_lens).item() + raise NotImplementedError( + "page_size > 1 and top_k > 1 are not supported." + ) + # TODO: Support page_size > 1 and top_k > 1 + # 1. Duplicate the KV cache in the last partial page for all top-k segments + # 2. Modify generate_draft_decode_kv_indices accordingly - # TODO(lmzheng): remove this device sync - extend_num_tokens = torch.sum(self.extend_lens).item() - + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) out_cache_loc, token_to_kv_pool_state_backup = ( batch.alloc_paged_token_slots_extend( prefix_lens, @@ -445,31 +423,19 @@ class EAGLEWorker(TpModelWorker): batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, - self.extend_lens, - self.num_new_pages_per_topk, out_cache_loc, batch.req_to_token_pool.req_to_token.shape[1], self.topk, self.speculative_num_steps, self.page_size, - next_power_of_2(num_seqs), - next_power_of_2(self.speculative_num_steps), ) - - if self.page_size > 1 and self.topk > 1: - # Remove padded slots - out_cache_loc = out_cache_loc[ - : num_seqs * self.topk * self.speculative_num_steps - ] - batch.out_cache_loc = out_cache_loc batch.seq_lens_sum = torch.sum(batch.seq_lens).item() - batch.return_hidden_states = False spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - - # Get forward batch + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -538,13 +504,6 @@ class EAGLEWorker(TpModelWorker): if self.hot_token_id is not None: topk_index = self.hot_token_id[topk_index] - out_cache_loc = out_cache_loc.reshape( - forward_batch.batch_size, self.topk, self.speculative_num_steps - ) - out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape( - self.speculative_num_steps, -1 - ) - # Return values score_list: List[torch.Tensor] = [] token_list: List[torch.Tensor] = [] @@ -566,7 +525,10 @@ class EAGLEWorker(TpModelWorker): # Set inputs forward_batch.input_ids = input_ids - forward_batch.out_cache_loc = out_cache_loc[i:] + out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1) + forward_batch.out_cache_loc = out_cache_loc[ + :, self.topk * i : self.topk * (i + 1) + ].flatten() forward_batch.positions.add_(1) forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i] spec_info.hidden_states = hidden_states @@ -624,7 +586,7 @@ class EAGLEWorker(TpModelWorker): if vocab_mask is not None: assert spec_info.grammar is not None vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device) - # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage + # otherwise, this vocab mask will be the one from the previous extend stage # and will be applied to produce wrong results batch.sampling_info.vocab_mask = None @@ -645,13 +607,13 @@ class EAGLEWorker(TpModelWorker): ] logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] - if batch.return_logprob: - self.add_logprob_values(batch, res, logits_output) - # Prepare the batch for the next draft forwards. batch.forward_mode = ForwardMode.DECODE batch.spec_info = res.draft_input + if batch.return_logprob: + self.add_logprob_values(batch, res, logits_output) + return logits_output, res, model_worker_batch, can_run_cuda_graph def add_logprob_values( @@ -664,16 +626,8 @@ class EAGLEWorker(TpModelWorker): logits_output = res.logits_output top_logprobs_nums = batch.top_logprobs_nums token_ids_logprobs = batch.token_ids_logprobs - accepted_indices = res.accepted_indices - assert len(accepted_indices) == len(logits_output.next_token_logits) - temperatures = batch.sampling_info.temperatures - num_draft_tokens = batch.spec_info.draft_token_num - # acceptance indices are the indices in a "flattened" batch. - # dividing it to num_draft_tokens will yield the actual batch index. - temperatures = temperatures[accepted_indices // num_draft_tokens] - logprobs = torch.nn.functional.log_softmax( - logits_output.next_token_logits / temperatures, dim=-1 + logits_output.next_token_logits, dim=-1 ) batch_next_token_ids = res.verified_id num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu] @@ -708,7 +662,7 @@ class EAGLEWorker(TpModelWorker): pt = 0 next_token_logprobs = logits_output.next_token_logprobs.tolist() verified_ids = batch_next_token_ids.tolist() - for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True): + for req, num_tokens in zip(batch.reqs, num_tokens_per_req): for _ in range(num_tokens): if req.return_logprob: req.output_token_logprobs_val.append(next_token_logprobs[pt]) @@ -736,6 +690,7 @@ class EAGLEWorker(TpModelWorker): hidden_states: Hidden states from the target model forward next_token_ids: Next token ids generated from the target forward. """ + # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last batch.spec_info = EagleDraftInput( hidden_states=hidden_states, verified_id=next_token_ids, @@ -746,6 +701,7 @@ class EAGLEWorker(TpModelWorker): model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=seq_lens_cpu ) + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -768,7 +724,9 @@ class EAGLEWorker(TpModelWorker): batch, self.speculative_num_steps, ) + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -832,47 +790,3 @@ def load_token_map(token_map_path: str) -> List[int]: token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) hot_token_id = torch.load(token_map_path, weights_only=True) return torch.tensor(hot_token_id, dtype=torch.int32) - - -@torch.compile(dynamic=True) -def get_last_loc_large_page_size_top_k_1( - req_to_token: torch.Tensor, - req_pool_indices: torch.Tensor, - seq_lens, - speculative_num_steps: int, -): - prefix_lens = seq_lens - seq_lens = prefix_lens + speculative_num_steps - last_loc = get_last_loc( - req_to_token, - req_pool_indices, - prefix_lens, - ) - return prefix_lens, seq_lens, last_loc - - -@torch.compile(dynamic=True) -def get_last_loc_large_page_size_large_top_k( - req_to_token: torch.Tensor, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - speculative_num_steps: int, - topk: int, - page_size: int, -): - prefix_lens = seq_lens - last_page_lens = prefix_lens % page_size - num_new_pages_per_topk = ( - last_page_lens + speculative_num_steps + page_size - 1 - ) // page_size - seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * ( - page_size * topk - ) - extend_lens = seq_lens - prefix_lens - last_loc = get_last_loc( - req_to_token, - req_pool_indices, - prefix_lens, - ) - - return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens diff --git a/test/srt/test_eagle_infer_b.py b/test/srt/test_eagle_infer_b.py index 5b9df1630..f71feb15a 100644 --- a/test/srt/test_eagle_infer_b.py +++ b/test/srt/test_eagle_infer_b.py @@ -441,71 +441,5 @@ class TestEAGLEServerTriton(TestEAGLEServer): ) -class TestEAGLEServerPageSize(TestEAGLEServer): - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - "--speculative-num-steps", - 5, - "--speculative-eagle-topk", - 1, - "--speculative-num-draft-tokens", - 6, - "--mem-fraction-static", - 0.7, - "--chunked-prefill-size", - 128, - "--max-running-requests", - 8, - "--page-size", - 4, - "--attention-backend", - "flashinfer", - ], - ) - - -class TestEAGLEServerPageSizeTopk(TestEAGLEServer): - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - "--speculative-num-steps", - 5, - "--speculative-eagle-topk", - 8, - "--speculative-num-draft-tokens", - 64, - "--mem-fraction-static", - 0.7, - "--chunked-prefill-size", - 128, - "--max-running-requests", - 8, - "--page-size", - 4, - "--attention-backend", - "flashinfer", - ], - ) - - if __name__ == "__main__": unittest.main()