diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 876141083..f11de5641 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend: kv_indices_buffer, self.kv_indptr, forward_batch.positions, - num_seqs, - self.topk, self.pool_len, kv_indices_buffer.shape[1], self.kv_indptr.shape[1], next_power_of_2(num_seqs), next_power_of_2(self.speculative_num_steps), next_power_of_2(bs), + self.page_size, ) assert forward_batch.spec_info is not None diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index c4192a715..19fa09818 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -789,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend: # Cached variables for generate_draft_decode_kv_indices self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] + self.page_size = model_runner.server_args.page_size def common_template( self, @@ -809,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend: kv_indices_buffer, self.kv_indptr, forward_batch.positions, - num_seqs, - self.topk, self.pool_len, kv_indices_buffer.shape[1], self.kv_indptr.shape[1], next_power_of_2(num_seqs), next_power_of_2(self.speculative_num_steps), next_power_of_2(bs), + self.page_size, ) assert forward_batch.spec_info is not None diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index cad4c1950..c688fd461 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -2,9 +2,6 @@ from __future__ import annotations """ Support attention backend for FlashMLA. - -#TODO -Enable speculative sampling in FlashMLA """ from dataclasses import dataclass diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 0a3aef9c3..970ceb999 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -784,14 +784,13 @@ class TritonMultiStepDraftBackend: kv_indices_buffer, self.kv_indptr, forward_batch.positions, - num_seqs, - self.topk, self.pool_len, kv_indices_buffer.shape[1], self.kv_indptr.shape[1], next_power_of_2(num_seqs), next_power_of_2(self.speculative_num_steps), next_power_of_2(bs), + self.page_size, ) for i in range(self.speculative_num_steps): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index bac310da3..1e823be10 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -294,6 +294,19 @@ class MHATokenToKVPool(KVCache): for _ in range(self.layer_num) ] + self.data_ptrs = torch.tensor( + [x.data_ptr() for x in self.k_buffer + self.v_buffer], + dtype=torch.uint64, + device=self.device, + ) + self.data_strides = torch.tensor( + [ + np.prod(x.shape[1:]) * x.dtype.itemsize + for x in self.k_buffer + self.v_buffer + ], + device=self.device, + ) + def _clear_buffers(self): del self.k_buffer del self.v_buffer @@ -451,6 +464,16 @@ class MHATokenToKVPool(KVCache): self.k_buffer[layer_id - self.start_layer][loc] = cache_k self.v_buffer[layer_id - self.start_layer][loc] = cache_v + def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): + copy_all_layer_kv_cache[(len(self.data_ptrs),)]( + self.data_ptrs, + self.data_strides, + tgt_loc, + src_loc, + len(tgt_loc), + next_power_of_2(len(tgt_loc)), + ) + @triton.jit def set_mla_kv_buffer_kernel( @@ -741,3 +764,41 @@ class DoubleSparseTokenToKVPool(KVCache): def transfer_per_layer(self, indices, flat_data, layer_id): pass + + +@triton.jit +def copy_all_layer_kv_cache( + data_ptrs, + strides, + tgt_loc_ptr, + src_loc_ptr, + num_locs, + num_locs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + + bid = tl.program_id(0) + stride = tl.load(strides + bid) + + data_ptr = tl.load(data_ptrs + bid) + data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8)) + + num_locs_offset = tl.arange(0, num_locs_upper) + tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs) + src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs) + + # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks + # because this copy is an inplace operation. + + num_loop = tl.cdiv(stride, BLOCK_SIZE) + for i in range(num_loop): + copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :] + value = tl.load( + data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask + ) + tl.store( + data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :], + value, + mask=mask, + ) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 8bb1222da..cedc2ee88 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -67,8 +67,6 @@ class EagleDraftInput: kv_indptr: torch.Tensor = None kv_indices: torch.Tensor = None - all_padding_lens: Optional[torch.Tensor] = None - def prepare_for_extend(self, batch: ScheduleBatch): # Prefill only generate 1 token. assert len(self.verified_id) == len(batch.seq_lens) @@ -93,6 +91,7 @@ class EagleDraftInput: batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend batch.return_logprob = False + batch.return_hidden_states = False self.capture_hidden_mode = CaptureHiddenMode.LAST self.accept_length.add_(1) @@ -116,10 +115,8 @@ class EagleDraftInput: req_to_token: torch.Tensor, ): bs = self.accept_length.numel() - qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) @@ -139,7 +136,6 @@ class EagleDraftInput: kv_indices, req_to_token.size(1), ) - return kv_indices, cum_kv_seq_len, qo_indptr, None def filter_batch(self, new_indices: torch.Tensor): @@ -270,7 +266,7 @@ class EagleVerifyInput: logits_output: torch.Tensor, token_to_kv_pool_allocator: TokenToKVPoolAllocator, page_size: int, - vocab_mask: Optional[torch.Tensor] = None, + vocab_mask: Optional[torch.Tensor] = None, # For grammar ) -> torch.Tensor: """ Verify and find accepted tokens based on logits output and batch @@ -294,6 +290,14 @@ class EagleVerifyInput: ) accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") + # Apply the custom logit processors if registered in the sampling info. + if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + logits_output.next_token_logits, + sampling_info, + num_tokens_in_batch=self.draft_token_num, + ) + # Apply penalty if sampling_info.penalizer_orchestrator.is_required: # This is a relaxed version of penalties for speculative decoding. @@ -355,7 +359,13 @@ class EagleVerifyInput: draft_probs = torch.zeros( target_probs.shape, dtype=torch.float32, device="cuda" ) + + # coins for rejection sampling coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda") + # coins for final sampling + coins_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device="cuda" + ) tree_speculative_sampling_target_only( predicts=predict, # mutable accept_index=accept_index, # mutable @@ -365,6 +375,7 @@ class EagleVerifyInput: retrive_next_token=self.retrive_next_token.to(torch.int32), retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), uniform_samples=coins, + # uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, threshold_single=global_server_args_dict[ @@ -387,8 +398,8 @@ class EagleVerifyInput: spec_steps=self.spec_steps, ) - new_accept_index = [] unfinished_index = [] + unfinished_accept_index = [] accept_index_cpu = accept_index.tolist() predict_cpu = predict.tolist() has_finished = False @@ -396,12 +407,10 @@ class EagleVerifyInput: # Iterate every accepted token and check if req has finished after append the token # should be checked BEFORE free kv cache slots for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): - new_accept_index_ = [] for j, idx in enumerate(accept_index_row): if idx == -1: break id = predict_cpu[idx] - # if not found_finished: req.output_ids.append(id) req.check_finished() if req.finished(): @@ -410,8 +419,6 @@ class EagleVerifyInput: accept_index[i, j + 1 :] = -1 break else: - new_accept_index_.append(idx) - # update grammar state if req.grammar is not None: try: req.grammar.accept_token(id) @@ -421,50 +428,104 @@ class EagleVerifyInput: ) raise e if not req.finished(): - new_accept_index.extend(new_accept_index_) unfinished_index.append(i) + if idx == -1: + unfinished_accept_index.append(accept_index[i, :j]) + else: + unfinished_accept_index.append(accept_index[i]) req.spec_verify_ct += 1 if has_finished: accept_length = (accept_index != -1).sum(dim=1) - 1 # Free the KV cache for unaccepted tokens + # TODO: fuse them accept_index = accept_index[accept_index != -1] verified_id = predict[accept_index] evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False - if page_size != 1: - align_evict_mask_to_page_size[len(batch.seq_lens),]( - batch.seq_lens, - evict_mask, - page_size, - self.draft_token_num, - next_power_of_2(self.draft_token_num), - ) + if page_size == 1: + # TODO: boolean array index leads to a device sync. Remove it. + token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) + else: + if self.topk == 1: + # Only evict full empty page. Do not evict partial empty page + align_evict_mask_to_page_size[len(batch.seq_lens),]( + batch.seq_lens, + evict_mask, + page_size, + self.draft_token_num, + next_power_of_2(self.draft_token_num), + ) + token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) + else: + # Shift the accepted tokens to the beginning. + # Only evict the last part + src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc( + batch.seq_lens, + batch.out_cache_loc, + accept_index, + accept_length, + self.draft_token_num, + page_size, + ) + to_free_slots = torch.empty( + (to_free_num_slots.sum().item(),), + dtype=torch.int64, + device=to_free_num_slots.device, + ) - token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) + # out_cache_loc: [0 1 2, 3 4 5, 6 7 8] + # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1] + # tgt_cache_loc: [0 1 , 3 4 , 6 ] + # to_free_slots: [ 2, 5, 7 8] + # to_free_slots also needs to be page-aligned without the first partial page + # + # split each row of out_cache_loc into two parts. + # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1 + # 2. the second part goes to to_free_slots. + get_target_cache_loc[(bs,)]( + tgt_cache_loc, + to_free_slots, + accept_length, + to_free_num_slots, + batch.out_cache_loc, + self.draft_token_num, + next_power_of_2(self.draft_token_num), + next_power_of_2(bs), + ) + + # Free the kv cache + token_to_kv_pool_allocator.free(to_free_slots) + + # Copy the kv cache + batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache( + tgt_cache_loc, src_cache_loc + ) # Construct EagleVerifyOutput if not has_finished: - batch.out_cache_loc = batch.out_cache_loc[accept_index] - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens + accept_length + 1, - batch.out_cache_loc, - batch.req_to_token_pool.req_to_token.shape[1], - next_power_of_2(bs), - ) + if page_size == 1 or self.topk == 1: + batch.out_cache_loc = batch.out_cache_loc[accept_index] + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) + else: + batch.out_cache_loc = tgt_cache_loc batch.seq_lens.add_(accept_length + 1) - accept_length_cpu = accept_length.tolist() draft_input = EagleDraftInput() draft_input.hidden_states = batch.spec_info.hidden_states[accept_index] draft_input.verified_id = verified_id draft_input.accept_length = accept_length - draft_input.accept_length_cpu = accept_length_cpu + draft_input.accept_length_cpu = accept_length.tolist() draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices @@ -472,47 +533,66 @@ class EagleVerifyInput: draft_input=draft_input, logits_output=logits_output, verified_id=verified_id, - accept_length_per_req_cpu=accept_length_cpu, + accept_length_per_req_cpu=draft_input.accept_length_cpu, accepted_indices=accept_index, ) else: - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens + accept_length + 1, - batch.out_cache_loc[accept_index], - batch.req_to_token_pool.req_to_token.shape[1], - next_power_of_2(bs), - ) - batch.seq_lens.add_(accept_length + 1) - accept_length_cpu = accept_length.tolist() + if page_size == 1 or self.topk == 1: + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) + accept_length_cpu = accept_length.tolist() draft_input = EagleDraftInput() - if len(new_accept_index) > 0: - new_accept_index = torch.tensor(new_accept_index, device="cuda") - unfinished_index_device = torch.tensor(unfinished_index, device="cuda") - draft_input.hidden_states = batch.spec_info.hidden_states[ - new_accept_index - ] - draft_input.verified_id = predict[new_accept_index] - draft_input.accept_length_cpu = [ + if len(unfinished_accept_index) > 0: + unfinished_accept_index = torch.cat(unfinished_accept_index) + unfinished_index_device = torch.tensor( + unfinished_index, dtype=torch.int64, device=predict.device + ) + draft_input_accept_length_cpu = [ accept_length_cpu[i] for i in unfinished_index ] - draft_input.accept_length = accept_length[unfinished_index_device] - if has_finished: - draft_input.seq_lens_for_draft_extend = batch.seq_lens[ - unfinished_index_device - ] - draft_input.req_pool_indices_for_draft_extend = ( - batch.req_pool_indices[unfinished_index_device] - ) + if page_size == 1 or self.topk == 1: + batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index] else: - draft_input.seq_lens_for_draft_extend = batch.seq_lens - draft_input.req_pool_indices_for_draft_extend = ( - batch.req_pool_indices + batch.out_cache_loc = torch.empty( + len(unfinished_index) + sum(draft_input_accept_length_cpu), + dtype=torch.int64, + device=predict.device, ) - batch.out_cache_loc = batch.out_cache_loc[new_accept_index] + accept_length_filter = create_accept_length_filter( + accept_length, + unfinished_index_device, + batch.seq_lens, + ) + filter_finished_cache_loc_kernel[(bs,)]( + batch.out_cache_loc, + tgt_cache_loc, + accept_length, + accept_length_filter, + next_power_of_2(bs), + next_power_of_2(self.draft_token_num), + ) + + draft_input.hidden_states = batch.spec_info.hidden_states[ + unfinished_accept_index + ] + draft_input.verified_id = predict[unfinished_accept_index] + draft_input.accept_length_cpu = draft_input_accept_length_cpu + draft_input.accept_length = accept_length[unfinished_index_device] + draft_input.seq_lens_for_draft_extend = batch.seq_lens[ + unfinished_index_device + ] + draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ + unfinished_index_device + ] return EagleVerifyOutput( draft_input=draft_input, @@ -589,36 +669,75 @@ def assign_draft_cache_locs( req_pool_indices, req_to_token, seq_lens, + extend_lens, + num_new_pages_per_topk, out_cache_loc, pool_len: tl.constexpr, topk: tl.constexpr, speculative_num_steps: tl.constexpr, page_size: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, ): - BLOCK_SIZE: tl.constexpr = 32 + BLOCK_SIZE: tl.constexpr = 128 pid = tl.program_id(axis=0) - kv_start = tl.load(seq_lens + pid) if page_size == 1 or topk == 1: - kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps + copy_len = topk * speculative_num_steps out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps else: - prefix_len = tl.load(seq_lens + pid) - last_page_len = prefix_len % page_size - num_new_page = ( - last_page_len + speculative_num_steps + page_size - 1 - ) // page_size - kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk) + bs_offset = tl.arange(0, bs_upper) + copy_len = tl.load(extend_lens + pid) + cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid)) + out_cache_ptr = out_cache_loc + cum_copy_len + # Part 1: Copy from out_cache_loc to req_to_token + kv_start = tl.load(seq_lens + pid) token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len - - num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) + num_loop = tl.cdiv(copy_len, BLOCK_SIZE) for i in range(num_loop): - save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start - load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = save_offset < kv_end - data = tl.load(out_cache_ptr + load_offset, mask=mask) - tl.store(token_pool + save_offset, data, mask=mask) + copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = copy_offset < copy_len + data = tl.load(out_cache_ptr + copy_offset, mask=mask) + tl.store(token_pool + kv_start + copy_offset, data, mask=mask) + + if page_size == 1 or topk == 1: + return + + # Part 2: Copy the indices for the last partial page + prefix_len = tl.load(seq_lens + pid) + last_page_len = prefix_len % page_size + offsets = tl.arange(0, page_size) + mask = offsets < last_page_len + num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid) + prefix_base = token_pool + prefix_len - last_page_len + + for topk_id in range(topk): + value = tl.load(prefix_base + offsets, mask=mask) + tl.store( + prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets, + value, + mask=mask, + ) + + # Part 3: Remove the padding in out_cache_loc + iter_offest = tl.arange(0, iter_upper) + for topk_id in range(topk): + indices = tl.load( + prefix_base + + topk_id * num_new_pages_per_topk_ * page_size + + last_page_len + + iter_offest, + mask=iter_offest < speculative_num_steps, + ) + tl.store( + out_cache_loc + + pid * topk * speculative_num_steps + + topk_id * speculative_num_steps + + iter_offest, + indices, + mask=iter_offest < speculative_num_steps, + ) @triton.jit @@ -629,20 +748,23 @@ def generate_draft_decode_kv_indices( kv_indices, kv_indptr, positions, - num_seqs: tl.constexpr, - topk: tl.constexpr, pool_len: tl.constexpr, kv_indices_stride: tl.constexpr, kv_indptr_stride: tl.constexpr, bs_upper: tl.constexpr, iter_upper: tl.constexpr, num_tokens_upper: tl.constexpr, + page_size: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 128 iters = tl.program_id(axis=0) bid = tl.program_id(axis=1) topk_id = tl.program_id(axis=2) + num_steps = tl.num_programs(axis=0) + num_seqs = tl.num_programs(axis=1) + topk = tl.num_programs(axis=2) + kv_indices += kv_indices_stride * iters kv_indptr += kv_indptr_stride * iters iters += 1 @@ -652,6 +774,7 @@ def generate_draft_decode_kv_indices( seq_len = tl.load(paged_kernel_lens + bid) cum_seq_len = tl.sum(seq_lens) + # Update kv_indices kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) kv_ptr = kv_indices + kv_offset token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len @@ -665,10 +788,26 @@ def generate_draft_decode_kv_indices( kv_offset += BLOCK_SIZE extend_offset = tl.arange(0, iter_upper) - extend_data = tl.load( - token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, - mask=extend_offset < iters, - ) + if page_size == 1 or topk == 1: + extend_data = tl.load( + token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper), + mask=extend_offset < iters, + ) + else: + prefix_len = seq_len + last_page_len = prefix_len % page_size + num_new_pages_per_topk = ( + last_page_len + num_steps + page_size - 1 + ) // page_size + prefix_base = seq_len // page_size * page_size + start = ( + prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len + ) + extend_data = tl.load( + token_pool_ptr + start + extend_offset, + mask=extend_offset < iters, + ) + tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) # Update kv_indptr @@ -707,6 +846,116 @@ def align_evict_mask_to_page_size( tl.store(evict_mask + bid * num_draft_tokens + i, False) +@triton.jit +def get_target_cache_loc( + tgt_cache_loc, + to_free_slots, + accept_length, + to_free_num_slots, + out_cache_loc, + num_verify_tokens: tl.constexpr, + num_verify_tokens_upper: tl.constexpr, + bs_upper: tl.constexpr, +): + bid = tl.program_id(axis=0) + offset = tl.arange(0, num_verify_tokens_upper) + bs_offset = tl.arange(0, bs_upper) + + # write the first part to tgt_cache_loc + accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) + tgt_cache_loc_start = tl.sum(accept_len_all) + bid + copy_len = tl.load(accept_length + bid) + 1 + out_cache_loc_row = tl.load( + out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len + ) + tl.store( + tgt_cache_loc + tgt_cache_loc_start + offset, + out_cache_loc_row, + mask=offset < copy_len, + ) + + # write the second part to to_free_num_pages + to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid) + to_free_num_slots_cur = tl.load(to_free_num_slots + bid) + out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur + to_free_slots_start = tl.sum(to_free_num_slots_all) + + copy_len = to_free_num_slots_cur + out_cache_loc_row = tl.load( + out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset, + mask=offset < copy_len, + ) + tl.store( + to_free_slots + to_free_slots_start + offset, + out_cache_loc_row, + mask=offset < copy_len, + ) + + +@torch.compile(dynamic=True) +def get_src_tgt_cache_loc( + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + accept_index: torch.Tensor, + accept_length: torch.Tensor, + draft_token_num: int, + page_size: int, +): + src_cache_loc = out_cache_loc[accept_index] + tgt_cache_loc = torch.empty_like(src_cache_loc) + extended_len = seq_lens + draft_token_num + keep_len = torch.minimum( + (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size, + extended_len, + ) + to_free_num_slots = extended_len - keep_len + return src_cache_loc, tgt_cache_loc, to_free_num_slots + + +@triton.jit +def filter_finished_cache_loc_kernel( + out_cache_loc, + tgt_cache_loc, + accept_length, + accept_length_filter, + bs_upper: tl.constexpr, + num_verify_tokens_upper: tl.constexpr, +): + bid = tl.program_id(0) + bs_offset = tl.arange(0, bs_upper) + + accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) + old_start = tl.sum(accept_length_all) + bid + + accept_length_filter_all = tl.load( + accept_length_filter + bs_offset, mask=bs_offset < bid + ) + new_start = tl.sum(accept_length_filter_all) + + copy_len = tl.load(accept_length_filter + bid) + copy_offset = tl.arange(0, num_verify_tokens_upper) + value = tl.load( + tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len + ) + tl.store( + out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len + ) + + +@torch.compile(dynamic=True) +def create_accept_length_filter( + accept_length: torch.Tensor, + unfinished_index_device: torch.Tensor, + seq_lens: torch.Tensor, +): + accept_length_filter = torch.zeros_like(accept_length) + accept_length_filter[unfinished_index_device] = ( + accept_length[unfinished_index_device] + 1 + ) + seq_lens.add_(accept_length + 1) + return accept_length_filter + + @torch.compile(dynamic=True) def select_top_k_tokens( i: int, @@ -756,6 +1005,16 @@ def select_top_k_tokens( return input_ids, hidden_states, scores, tree_info +def fast_topk_torch(values, topk, dim): + if topk == 1: + # Use max along the specified dimension to get both value and index + max_value, max_index = torch.max(values, dim=dim) + return max_value.unsqueeze(1), max_index.unsqueeze(1) + else: + # Use topk for efficiency with larger k values + return torch.topk(values, topk, dim=dim) + + def _generate_simulated_accept_index( accept_index, predict, @@ -765,15 +1024,35 @@ def _generate_simulated_accept_index( spec_steps, ): simulate_acc_len_float = float(simulate_acc_len) - simulated_values = torch.normal( - mean=simulate_acc_len_float, - std=1.0, - size=(1,), - device="cpu", - ) - # clamp simulated values to be between 1 and self.spec_steps - simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps) - simulate_acc_len = int(simulated_values.round().item()) + if SIMULATE_ACC_METHOD == "multinomial": + simulated_values = torch.normal( + mean=simulate_acc_len_float, + std=1.0, + size=(1,), + device="cpu", + ) + # clamp simulated values to be between 1 and self.spec_steps + simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1) + simulate_acc_len = int(simulated_values.round().item()) + elif SIMULATE_ACC_METHOD == "match-expected": + # multinomial sampling does not match the expected length + # we keep it for the sake of compatibility of existing tests + # but it's better to use "match-expected" for the cases that need to + # match the expected length, One caveat is that this will only sample + # either round down or round up of the expected length + simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float)) + lower = int(simulate_acc_len_float // 1) + upper = lower + 1 if lower < spec_steps + 1 else lower + if lower == upper: + simulate_acc_len = lower + else: + weight_upper = simulate_acc_len_float - lower + weight_lower = 1.0 - weight_upper + probs = torch.tensor([weight_lower, weight_upper], device="cpu") + sampled_index = torch.multinomial(probs, num_samples=1) + simulate_acc_len = lower if sampled_index == 0 else upper + else: + raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}") accept_indx_first_col = accept_index[:, 0].view(-1, 1) sim_accept_index = torch.full( @@ -864,9 +1143,9 @@ def generate_token_bitmask( """ Generate the logit mask for structured output. Draft model's token can be either valid or invalid with respect to the grammar. - We need to perform DFS to figure out: - 1. which tokens are accepted by the grammar - 2. what is the corresponding logit mask. + We need to perform DFS to + 1. figure out which tokens are accepted by the grammar. + 2. if so, what is the corresponding logit mask. """ num_draft_tokens = draft_tokens_cpu.shape[-1] @@ -883,6 +1162,7 @@ def generate_token_bitmask( device="cpu", ) grammar = req.grammar + s = time.perf_counter() traverse_tree( retrieve_next_token_cpu[i], retrieve_next_sibling_cpu[i], @@ -892,6 +1172,12 @@ def generate_token_bitmask( i * num_draft_tokens : (i + 1) * num_draft_tokens ], ) + tree_traverse_time = time.perf_counter() - s + if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD: + logger.warning( + f"Bit mask generation took {tree_traverse_time} seconds with " + f"grammar: {req.grammar}" + ) verify_input.grammar = grammar return allocate_token_bitmask diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 0597ad4e0..e42515dce 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -35,11 +35,17 @@ from sglang.srt.speculative.eagle_utils import ( EagleVerifyInput, EagleVerifyOutput, assign_draft_cache_locs, + fast_topk, generate_token_bitmask, select_top_k_tokens, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda +from sglang.srt.utils import ( + empty_context, + get_available_gpu_memory, + is_cuda, + next_power_of_2, +) if is_cuda(): from sgl_kernel import segment_packbits @@ -152,6 +158,12 @@ class EAGLEWorker(TpModelWorker): self.init_attention_backend() self.init_cuda_graphs() + # Some dummy tensors + self.num_new_pages_per_topk = torch.empty( + (), dtype=torch.int64, device=self.device + ) + self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device) + def init_attention_backend(self): # Create multi-step attn backends and cuda graph runners if self.server_args.attention_backend == "flashinfer": @@ -254,7 +266,7 @@ class EAGLEWorker(TpModelWorker): self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." + f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) # Capture extend @@ -269,7 +281,7 @@ class EAGLEWorker(TpModelWorker): ) after_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." + f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) @property @@ -290,7 +302,6 @@ class EAGLEWorker(TpModelWorker): A tuple of the final logit output of the target model, next tokens accepted, the batch id (used for overlap schedule), and number of accepted tokens. """ - if batch.forward_mode.is_decode(): with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) @@ -366,14 +377,21 @@ class EAGLEWorker(TpModelWorker): ) # Allocate cache locations + # Layout of the out_cache_loc + # [ topk 0 ] [ topk 1 ] + # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2] if self.page_size == 1: out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots( - num_seqs * self.topk * self.speculative_num_steps, backup_state=True + num_seqs * self.speculative_num_steps * self.topk, backup_state=True ) else: if self.topk == 1: - prefix_lens = batch.seq_lens - seq_lens = prefix_lens + self.speculative_num_steps + prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + batch.seq_lens, + self.speculative_num_steps, + ) extend_num_tokens = num_seqs * self.speculative_num_steps else: # In this case, the last partial page needs to be duplicated. @@ -386,29 +404,33 @@ class EAGLEWorker(TpModelWorker): # "x" means speculative draft tokens # "." means padded tokens - # TODO: fuse these ops - prefix_lens = batch.seq_lens - last_page_lens = prefix_lens % self.page_size - num_new_pages = ( - last_page_lens + self.speculative_num_steps + self.page_size - 1 - ) // self.page_size - seq_lens = ( - prefix_lens // self.page_size * self.page_size - + num_new_pages * (self.page_size * self.topk) - ) - extend_num_tokens = torch.sum(seq_lens - prefix_lens).item() - raise NotImplementedError( - "page_size > 1 and top_k > 1 are not supported." - ) - # TODO: Support page_size > 1 and top_k > 1 - # 1. Duplicate the KV cache in the last partial page for all top-k segments - # 2. Modify generate_draft_decode_kv_indices accordingly + # TODO(lmzheng): The current implementation is still a fake support + # for page size > 1. In the `assign_draft_cache_locs` below, + # we directly move the indices instead of the real kv cache. + # This only works when the kernel backend runs with page size = 1. + # If the kernel backend runs with page size > 1, we need to + # duplicate the real KV cache. The overhead of duplicating KV + # cache seems okay because the draft KV cache only has one layer. + # see a related copy operation in MHATokenToKVPool::move_kv_cache. + + ( + prefix_lens, + seq_lens, + last_loc, + self.num_new_pages_per_topk, + self.extend_lens, + ) = get_last_loc_large_page_size_large_top_k( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + batch.seq_lens, + self.speculative_num_steps, + self.topk, + self.page_size, + ) + + # TODO(lmzheng): remove this device sync + extend_num_tokens = torch.sum(self.extend_lens).item() - last_loc = get_last_loc( - batch.req_to_token_pool.req_to_token, - batch.req_pool_indices, - prefix_lens, - ) out_cache_loc, token_to_kv_pool_state_backup = ( batch.alloc_paged_token_slots_extend( prefix_lens, @@ -423,19 +445,31 @@ class EAGLEWorker(TpModelWorker): batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, + self.extend_lens, + self.num_new_pages_per_topk, out_cache_loc, batch.req_to_token_pool.req_to_token.shape[1], self.topk, self.speculative_num_steps, self.page_size, + next_power_of_2(num_seqs), + next_power_of_2(self.speculative_num_steps), ) + + if self.page_size > 1 and self.topk > 1: + # Remove padded slots + out_cache_loc = out_cache_loc[ + : num_seqs * self.topk * self.speculative_num_steps + ] + batch.out_cache_loc = out_cache_loc batch.seq_lens_sum = torch.sum(batch.seq_lens).item() + batch.return_hidden_states = False spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - batch.return_hidden_states = False + + # Get forward batch model_worker_batch = batch.get_model_worker_batch() - assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -449,9 +483,6 @@ class EAGLEWorker(TpModelWorker): else: # Initialize attention backend self.draft_attn_backend.init_forward_metadata(forward_batch) - forward_batch = ForwardBatch.init_new( - model_worker_batch, self.draft_model_runner - ) # Run forward steps score_list, token_list, parents_list = self.draft_forward(forward_batch) @@ -504,6 +535,13 @@ class EAGLEWorker(TpModelWorker): if self.hot_token_id is not None: topk_index = self.hot_token_id[topk_index] + out_cache_loc = out_cache_loc.reshape( + forward_batch.batch_size, self.topk, self.speculative_num_steps + ) + out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape( + self.speculative_num_steps, -1 + ) + # Return values score_list: List[torch.Tensor] = [] token_list: List[torch.Tensor] = [] @@ -525,10 +563,7 @@ class EAGLEWorker(TpModelWorker): # Set inputs forward_batch.input_ids = input_ids - out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1) - forward_batch.out_cache_loc = out_cache_loc[ - :, self.topk * i : self.topk * (i + 1) - ].flatten() + forward_batch.out_cache_loc = out_cache_loc[i] forward_batch.positions.add_(1) forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i] spec_info.hidden_states = hidden_states @@ -586,7 +621,7 @@ class EAGLEWorker(TpModelWorker): if vocab_mask is not None: assert spec_info.grammar is not None vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device) - # otherwise, this vocab mask will be the one from the previous extend stage + # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage # and will be applied to produce wrong results batch.sampling_info.vocab_mask = None @@ -607,13 +642,13 @@ class EAGLEWorker(TpModelWorker): ] logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] + if batch.return_logprob: + self.add_logprob_values(batch, res, logits_output) + # Prepare the batch for the next draft forwards. batch.forward_mode = ForwardMode.DECODE batch.spec_info = res.draft_input - if batch.return_logprob: - self.add_logprob_values(batch, res, logits_output) - return logits_output, res, model_worker_batch, can_run_cuda_graph def add_logprob_values( @@ -626,8 +661,16 @@ class EAGLEWorker(TpModelWorker): logits_output = res.logits_output top_logprobs_nums = batch.top_logprobs_nums token_ids_logprobs = batch.token_ids_logprobs + accepted_indices = res.accepted_indices + assert len(accepted_indices) == len(logits_output.next_token_logits) + temperatures = batch.sampling_info.temperatures + num_draft_tokens = batch.spec_info.draft_token_num + # acceptance indices are the indices in a "flattened" batch. + # dividing it to num_draft_tokens will yield the actual batch index. + temperatures = temperatures[accepted_indices // num_draft_tokens] + logprobs = torch.nn.functional.log_softmax( - logits_output.next_token_logits, dim=-1 + logits_output.next_token_logits / temperatures, dim=-1 ) batch_next_token_ids = res.verified_id num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu] @@ -662,7 +705,7 @@ class EAGLEWorker(TpModelWorker): pt = 0 next_token_logprobs = logits_output.next_token_logprobs.tolist() verified_ids = batch_next_token_ids.tolist() - for req, num_tokens in zip(batch.reqs, num_tokens_per_req): + for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True): for _ in range(num_tokens): if req.return_logprob: req.output_token_logprobs_val.append(next_token_logprobs[pt]) @@ -690,7 +733,6 @@ class EAGLEWorker(TpModelWorker): hidden_states: Hidden states from the target model forward next_token_ids: Next token ids generated from the target forward. """ - # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last batch.spec_info = EagleDraftInput( hidden_states=hidden_states, verified_id=next_token_ids, @@ -701,7 +743,6 @@ class EAGLEWorker(TpModelWorker): model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=seq_lens_cpu ) - assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -724,9 +765,7 @@ class EAGLEWorker(TpModelWorker): batch, self.speculative_num_steps, ) - batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() - assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -790,3 +829,47 @@ def load_token_map(token_map_path: str) -> List[int]: token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) hot_token_id = torch.load(token_map_path, weights_only=True) return torch.tensor(hot_token_id, dtype=torch.int32) + + +@torch.compile(dynamic=True) +def get_last_loc_large_page_size_top_k_1( + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens, + speculative_num_steps: int, +): + prefix_lens = seq_lens + seq_lens = prefix_lens + speculative_num_steps + last_loc = get_last_loc( + req_to_token, + req_pool_indices, + prefix_lens, + ) + return prefix_lens, seq_lens, last_loc + + +@torch.compile(dynamic=True) +def get_last_loc_large_page_size_large_top_k( + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + speculative_num_steps: int, + topk: int, + page_size: int, +): + prefix_lens = seq_lens + last_page_lens = prefix_lens % page_size + num_new_pages_per_topk = ( + last_page_lens + speculative_num_steps + page_size - 1 + ) // page_size + seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * ( + page_size * topk + ) + extend_lens = seq_lens - prefix_lens + last_loc = get_last_loc( + req_to_token, + req_pool_indices, + prefix_lens, + ) + + return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens diff --git a/test/srt/test_eagle_infer_b.py b/test/srt/test_eagle_infer_b.py index f71feb15a..5b9df1630 100644 --- a/test/srt/test_eagle_infer_b.py +++ b/test/srt/test_eagle_infer_b.py @@ -441,5 +441,71 @@ class TestEAGLEServerTriton(TestEAGLEServer): ) +class TestEAGLEServerPageSize(TestEAGLEServer): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + 5, + "--speculative-eagle-topk", + 1, + "--speculative-num-draft-tokens", + 6, + "--mem-fraction-static", + 0.7, + "--chunked-prefill-size", + 128, + "--max-running-requests", + 8, + "--page-size", + 4, + "--attention-backend", + "flashinfer", + ], + ) + + +class TestEAGLEServerPageSizeTopk(TestEAGLEServer): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + 5, + "--speculative-eagle-topk", + 8, + "--speculative-num-draft-tokens", + 64, + "--mem-fraction-static", + 0.7, + "--chunked-prefill-size", + 128, + "--max-running-requests", + 8, + "--page-size", + 4, + "--attention-backend", + "flashinfer", + ], + ) + + if __name__ == "__main__": unittest.main()