From 0b2aa8a70ca503abbde4905cf0f857933fbdc928 Mon Sep 17 00:00:00 2001 From: Zhang Junda Date: Thu, 2 Oct 2025 10:51:25 +0800 Subject: [PATCH] Intoduce cpu tensor as metadata to avoid blocking gpu kernel launch (#10720) Co-authored-by: hnyls2002 --- .../decode_schedule_batch_mixin.py | 1 + python/sglang/srt/managers/schedule_batch.py | 42 ++++++++++++++----- python/sglang/srt/mem_cache/allocator.py | 28 ++++--------- .../sglang/srt/mem_cache/allocator_ascend.py | 10 +++-- python/sglang/srt/speculative/eagle_info.py | 25 ++++++++--- python/sglang/srt/speculative/eagle_worker.py | 19 +++++++-- python/sglang/srt/speculative/ngram_utils.py | 10 +++-- python/sglang/srt/utils.py | 24 +++++++++++ 8 files changed, 115 insertions(+), 44 deletions(-) diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index 277c84f9d..6812397f5 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin: req_pool_indices, dtype=torch.int64, device=self.device ) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) + self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) self.orig_seq_lens = torch.tensor( seq_lens, dtype=torch.int32, device=self.device ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 32df8e26e..b6329cb28 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): token_type_ids: torch.Tensor = None # shape: [b], int64 req_pool_indices: torch.Tensor = None # shape: [b], int64 seq_lens: torch.Tensor = None # shape: [b], int64 + seq_lens_cpu: torch.Tensor = None # shape: [b], int64 # The output locations of the KV cache out_cache_loc: torch.Tensor = None # shape: [b], int64 output_ids: torch.Tensor = None # shape: [b], int64 @@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def alloc_paged_token_slots_extend( self, prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, extend_num_tokens: int, backup_state: bool = False, @@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Over estimate the number of tokens: assume each request needs a new page. num_tokens = ( extend_num_tokens - + len(seq_lens) * self.token_to_kv_pool_allocator.page_size + + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size ) self._evict_tree_cache_if_needed(num_tokens) @@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): state = self.token_to_kv_pool_allocator.backup_state() out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend( - prefix_lens, seq_lens, last_loc, extend_num_tokens + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + extend_num_tokens, ) if out_cache_loc is None: error_msg = ( @@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def alloc_paged_token_slots_decode( self, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, backup_state: bool = False, ): @@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if backup_state: state = self.token_to_kv_pool_allocator.backup_state() - out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc) + out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode( + seq_lens, seq_lens_cpu, last_loc + ) if out_cache_loc is None: error_msg = ( f"Decode out of memory. Try to lower your batch size.\n" @@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) + self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) if not decoder_out_cache_loc: self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to( @@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) + seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64) orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( self.device, non_blocking=True ) prefix_lens_tensor = torch.tensor( prefix_lens, dtype=torch.int64, device=self.device ) + prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64) token_type_ids_tensor = None if len(token_type_ids) > 0: @@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): prefix_lens_tensor, ) out_cache_loc = self.alloc_paged_token_slots_extend( - prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens + prefix_lens_tensor, + prefix_lens_cpu_tensor, + seq_lens_tensor, + seq_lens_cpu_tensor, + last_loc, + extend_num_tokens, ) # Set fields self.input_ids = input_ids_tensor self.req_pool_indices = req_pool_indices_tensor self.seq_lens = seq_lens_tensor + self.seq_lens_cpu = seq_lens_cpu_tensor self.orig_seq_lens = orig_seq_lens_tensor self.out_cache_loc = out_cache_loc self.input_embeds = ( @@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) retracted_reqs = [] - seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True while first_iter or ( not self.check_decode_mem(selected_indices=sorted_indices) @@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): req = self.reqs[idx] - seq_lens_cpu = self.seq_lens.cpu().numpy() + seq_lens_cpu = self.seq_lens_cpu.numpy() if server_args.disaggregation_mode == "decode": req.offload_kv_cache( @@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.forward_mode = ForwardMode.IDLE self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device) self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device) + self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) @@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if self.enable_overlap: # Do not use in-place operations in the overlap mode self.seq_lens = self.seq_lens + 1 + self.seq_lens_cpu = self.seq_lens_cpu + 1 self.orig_seq_lens = self.orig_seq_lens + 1 else: # A faster in-place version self.seq_lens.add_(1) + self.seq_lens_cpu.add_(1) self.orig_seq_lens.add_(1) self.seq_lens_sum += bs @@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.req_pool_indices, self.seq_lens - 2 ] self.out_cache_loc = self.alloc_paged_token_slots_decode( - self.seq_lens, last_loc + self.seq_lens, self.seq_lens_cpu, last_loc ) self.req_to_token_pool.write( @@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices] self.req_pool_indices = self.req_pool_indices[keep_indices_device] self.seq_lens = self.seq_lens[keep_indices_device] + self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None self.seq_lens_sum = self.seq_lens.sum().item() @@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.cat([self.seq_lens, other.seq_lens]) + self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum @@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.sampling_info.grammars = None seq_lens_cpu = ( - seq_lens_cpu_cache - if seq_lens_cpu_cache is not None - else self.seq_lens.cpu() + seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu ) global bid diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 497331673..e3314ab60 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -27,7 +27,7 @@ import triton import triton.language as tl from sglang.srt.mem_cache.memory_pool import SWAKVPool -from sglang.srt.utils import get_bool_env_var, next_power_of_2 +from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2 if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import KVCache @@ -294,7 +294,6 @@ def alloc_extend_kernel( last_loc_ptr, free_page_ptr, out_indices, - ret_values, bs_upper: tl.constexpr, page_size: tl.constexpr, max_num_extend_tokens: tl.constexpr, @@ -323,13 +322,6 @@ def alloc_extend_kernel( sum_num_new_pages = tl.sum(num_new_pages) new_page_start_loc = sum_num_new_pages - num_page_start_loc_self - # Return value - if pid == tl.num_programs(0) - 1: - merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to( - tl.int64 - ) - tl.store(ret_values, merged_value) - # Part 1: fill the old partial page last_loc = tl.load(last_loc_ptr + pid) num_part1 = ( @@ -381,7 +373,6 @@ def alloc_decode_kernel( last_loc_ptr, free_page_ptr, out_indices, - ret_values, bs_upper: tl.constexpr, page_size: tl.constexpr, ): @@ -404,10 +395,6 @@ def alloc_decode_kernel( sum_num_new_pages = tl.sum(num_new_pages) new_page_start_loc = sum_num_new_pages - num_page_start_loc_self - # Return value - if pid == tl.num_programs(0) - 1: - tl.store(ret_values, sum_num_new_pages) - if num_page_start_loc_self == 0: last_loc = tl.load(last_loc_ptr + pid) tl.store(out_indices + pid, last_loc + 1) @@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): super().__init__(size, page_size, dtype, device, kvcache, need_sort) self.num_pages = size // page_size self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL") - self.ret_values = torch.empty((), dtype=torch.int64, device=self.device) self.seen_max_num_extend_tokens_next_power_of_2 = 1 self.clear() @@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): def alloc_extend( self, prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, extend_num_tokens: int, ): @@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): last_loc, self.free_pages, out_indices, - self.ret_values, next_power_of_2(bs), self.page_size, self.seen_max_num_extend_tokens_next_power_of_2, @@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) - merged_value = self.ret_values.item() - num_new_pages = merged_value >> 32 + num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size) if num_new_pages > len(self.free_pages): return None @@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): def alloc_decode( self, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, ): if self.debug_mode: @@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): last_loc, self.free_pages, out_indices, - self.ret_values, next_power_of_2(bs), self.page_size, ) @@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) - num_new_pages = self.ret_values.item() + num_new_pages = get_num_new_pages( + seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True + ) if num_new_pages > len(self.free_pages): return None diff --git a/python/sglang/srt/mem_cache/allocator_ascend.py b/python/sglang/srt/mem_cache/allocator_ascend.py index 2af138a6c..546e3b45a 100644 --- a/python/sglang/srt/mem_cache/allocator_ascend.py +++ b/python/sglang/srt/mem_cache/allocator_ascend.py @@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): def alloc_extend( self, prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, extend_num_tokens: int, ): @@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): num_new_pages = ( ( - (seq_lens + self.page_size - 1) // self.page_size - - (prefix_lens + self.page_size - 1) // self.page_size + (seq_lens_cpu + self.page_size - 1) // self.page_size + - (prefix_lens_cpu + self.page_size - 1) // self.page_size ) .sum() .item() @@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): def alloc_decode( self, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, ): if self.debug_mode: @@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): ) need_new_pages = (seq_lens % self.page_size == 1).int() - num_new_pages = need_new_pages.sum().item() + need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int() + num_new_pages = need_new_pages_cpu.sum().item() if num_new_pages > len(self.free_pages): self.merge_and_sort_free() diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 18a787256..6ab1499f9 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput): end_offset = batch.seq_lens + self.draft_token_num else: prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num last_loc = get_last_loc( batch.req_to_token_pool.req_to_token, batch.req_pool_indices, prefix_lens, ) batch.out_cache_loc = batch.alloc_paged_token_slots_extend( - prefix_lens, end_offset, last_loc, len(batch.input_ids) + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), ) self.last_loc = last_loc @@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput): verified_id = predict[accept_index] evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False + accept_length_cpu = accept_length.cpu() + accept_length_list = accept_length_cpu.tolist() if page_size == 1: # TODO: boolean array index leads to a device sync. Remove it. @@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput): else: batch.out_cache_loc = tgt_cache_loc batch.seq_lens.add_(accept_length + 1) + batch.seq_lens_cpu.add_(accept_length_cpu + 1) draft_input = EagleDraftInput( hidden_states=batch.spec_info.hidden_states[accept_index], verified_id=verified_id, accept_length=accept_length, - accept_length_cpu=accept_length.tolist(), + accept_length_cpu=accept_length_list, seq_lens_for_draft_extend=batch.seq_lens, + seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu, req_pool_indices_for_draft_extend=batch.req_pool_indices, ) @@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput): next_power_of_2(bs), ) batch.seq_lens.add_(accept_length + 1) + batch.seq_lens_cpu.add_(accept_length_cpu + 1) - accept_length_cpu = accept_length.tolist() if len(unfinished_accept_index) > 0: unfinished_accept_index = torch.cat(unfinished_accept_index) unfinished_index_device = torch.tensor( unfinished_index, dtype=torch.int64, device=predict.device ) draft_input_accept_length_cpu = [ - accept_length_cpu[i] for i in unfinished_index + accept_length_list[i] for i in unfinished_index ] if page_size == 1 or self.topk == 1: batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index] @@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput): unfinished_index_device, batch.seq_lens, ) + batch.seq_lens_cpu.add_(accept_length_cpu + 1) filter_finished_cache_loc_kernel[(bs,)]( batch.out_cache_loc, tgt_cache_loc, @@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput): accept_length_cpu=draft_input_accept_length_cpu, accept_length=accept_length[unfinished_index_device], seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device], + seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index], req_pool_indices_for_draft_extend=batch.req_pool_indices[ unfinished_index_device ], @@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput): draft_input=draft_input, logits_output=logits_output, verified_id=verified_id, - accept_length_per_req_cpu=accept_length_cpu, + accept_length_per_req_cpu=accept_length_list, accepted_indices=accept_index, ) @@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput): # Inputs for draft extend # shape: (b,) seq_lens_for_draft_extend: torch.Tensor = None + seq_lens_for_draft_extend_cpu: torch.Tensor = None req_pool_indices_for_draft_extend: torch.Tensor = None def __post_init__(self): @@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput): batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend batch.return_logprob = False batch.return_hidden_states = False diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 1782d6da0..f115f3eb8 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker): batch.seq_lens, self.speculative_num_steps, ) + prefix_lens_cpu = batch.seq_lens_cpu + seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps extend_num_tokens = num_seqs * self.speculative_num_steps else: # In this case, the last partial page needs to be duplicated. @@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker): self.topk, self.page_size, ) - - # TODO(lmzheng): remove this device sync - extend_num_tokens = torch.sum(self.extend_lens).item() + prefix_lens_cpu = batch.seq_lens_cpu + last_page_lens = prefix_lens_cpu % self.page_size + num_new_pages_per_topk = ( + last_page_lens + self.speculative_num_steps + self.page_size - 1 + ) // self.page_size + seq_lens_cpu = ( + prefix_lens_cpu // self.page_size * self.page_size + + num_new_pages_per_topk * (self.page_size * self.topk) + ) + extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item() out_cache_loc, token_to_kv_pool_state_backup = ( batch.alloc_paged_token_slots_extend( prefix_lens, + prefix_lens_cpu, seq_lens, + seq_lens_cpu, last_loc, extend_num_tokens, backup_state=True, @@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker): assert isinstance(batch.spec_info, EagleDraftInput) # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() + seq_lens_cpu_backup = batch.seq_lens_cpu.clone() req_pool_indices_backup = batch.req_pool_indices accept_length_backup = batch.spec_info.accept_length return_logprob_backup = batch.return_logprob @@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker): ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE ) batch.seq_lens = seq_lens_backup + batch.seq_lens_cpu = seq_lens_cpu_backup batch.req_pool_indices = req_pool_indices_backup batch.spec_info.accept_length = accept_length_backup batch.return_logprob = return_logprob_backup diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index ad4a332bd..79d66a047 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput): batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) end_offset = batch.seq_lens + self.draft_token_num else: + # TODO(lsyin): add prefix lens cpu here to support page size > 1 prefix_lens = batch.seq_lens end_offset = prefix_lens + self.draft_token_num last_loc = get_last_loc( @@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput): self._fill_requests(batch, logits_output) self._free_cache(batch, page_size) - batch.seq_lens.add_(self.accept_length + 1) - batch.seq_lens_sum = torch.sum(batch.seq_lens).item() + accept_length_cpu = self.accept_length.cpu() + num_accepted_tokens = accept_length_cpu.sum().item() - return logits_output, self.verified_id, self.accept_length.sum().item() + batch.seq_lens.add_(self.accept_length + 1) + batch.seq_lens_cpu.add_(accept_length_cpu + 1) + + return logits_output, self.verified_id, num_accepted_tokens def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): pass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index dce5db06f..9f38f149c 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit( return page_size + 2 * max(sliding_window_size, chunked_prefill_size) +def get_num_new_pages( + prefix_lens: torch.Tensor, + seq_lens: torch.Tensor, + page_size: int, + decode: bool = False, +) -> torch.Tensor: + """ + Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch. + """ + cpu_device = torch.device("cpu") + assert prefix_lens.device == cpu_device + assert seq_lens.device == cpu_device + num_pages_after = (seq_lens + page_size - 1) // page_size + num_pages_before = (prefix_lens + page_size - 1) // page_size + num_new_pages = num_pages_after - num_pages_before + extend_lens = seq_lens - prefix_lens + sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64) + if decode: + return sum_num_new_pages.item() + merged_value = (sum_num_new_pages) << 32 | torch.sum(extend_lens).to(torch.int64) + + return merged_value.item() >> 32 + + class CachedKernel: """ Wrapper that allows kernel[grid](...) syntax with caching based on a key function.