diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 9def9d8d0..f8a35266b 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -204,7 +204,6 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
             origin_input_ids=tmp_input_ids,
             sampling_params=sampling_params,
         )
-        req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         req.logprob_start_len = len(req.origin_input_ids) - 1
@@ -248,7 +247,6 @@ def prepare_synthetic_inputs_for_latency_test(
             origin_input_ids=list(input_ids[i]),
             sampling_params=sampling_params,
         )
-        req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         req.logprob_start_len = len(req.origin_input_ids) - 1
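The two deletions above are safe because `Req.__init__` (updated below in `schedule_batch.py`) now initializes `prefix_indices` to an empty int64 tensor. A minimal standalone sketch, not part of this diff, of why the tensor default is a drop-in replacement for the old `[]`:

```python
import torch

# New default from Req.__init__ below.
prefix_indices = torch.empty((0,), dtype=torch.int64)

# len() works on a 1-D tensor just like on a list, so expressions such as
# `len(req.fill_ids) - len(req.prefix_indices)` are unchanged.
assert len(prefix_indices) == 0

# Unlike a list, the tensor default also supports the tensor-only operations
# the new code paths rely on: slicing out the last cached location ...
assert prefix_indices[-1:].shape == (0,)
# ... and concatenation with other index tensors.
assert torch.cat([prefix_indices, torch.tensor([3, 4])]).tolist() == [3, 4]
```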
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index cfd607cc5..15cdd555a 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -539,7 +539,7 @@ class Req:
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor = []
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
@@ -691,11 +691,16 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(
-        self,
-        tree_cache: Optional[BasePrefixCache] = None,
-    ):
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
+        max_prefix_len = input_len - 1
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+        max_prefix_len = max(max_prefix_len, 0)
+        token_ids = self.fill_ids[:max_prefix_len]
+
         if tree_cache is not None:
             (
                 self.prefix_indices,
@@ -703,31 +708,11 @@ class Req:
                 self.last_host_node,
                 self.host_hit_length,
             ) = tree_cache.match_prefix(
-                key=RadixKey(
-                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
-                ),
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
             )
         self.last_matched_prefix_len = len(self.prefix_indices)
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
-    def adjust_max_prefix_ids(self):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
-        max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
-        if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
-        max_prefix_len = max(max_prefix_len, 0)
-        return self.fill_ids[:max_prefix_len]
-
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
@@ -808,7 +793,7 @@ class Req:
         return
 
     def reset_for_retract(self):
-        self.prefix_indices = []
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
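A worked example of the clamp now inlined into `init_next_round_input`, using toy values that are not from this diff. The cache match is capped at `input_len - 1` so at least one token is always left to prefill (which keeps logits/logprob computation possible), and additionally at `logprob_start_len` when logprobs are requested:

```python
fill_ids = list(range(10))   # 10 input tokens (toy values)
logprob_start_len = 4        # logprobs wanted from token 4 onward
return_logprob = True

max_prefix_len = len(fill_ids) - 1   # 9: always leave >= 1 token to extend
if return_logprob:
    # Tokens from logprob_start_len onward must be recomputed, not cache-hit.
    max_prefix_len = min(max_prefix_len, logprob_start_len)
max_prefix_len = max(max_prefix_len, 0)

token_ids = fill_ids[:max_prefix_len]
assert token_ids == [0, 1, 2, 3]     # only these may match the radix cache
```

This also folds away the old `max_new_tokens > 0` branch, which was a no-op: `min(max_prefix_len, input_len - 1)` could never change a value already equal to `input_len - 1`.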
@@ -1124,6 +1109,47 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return out_cache_loc
 
+    def write_cache_indices(
+        self,
+        req_pool_indices: List[int],
+        prefix_lens: List[int],
+        seq_lens: List[int],
+        extend_lens: List[int],
+        out_cache_loc: torch.Tensor,
+        req_pool_indices_tensor: torch.Tensor,
+        prefix_lens_tensor: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
+        extend_lens_tensor: torch.Tensor,
+        prefix_tensors: list[torch.Tensor],
+    ):
+        if support_triton(global_server_args_dict.get("attention_backend")):
+            prefix_pointers = torch.tensor(
+                [t.data_ptr() for t in prefix_tensors], device=self.device
+            )
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+            write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_pointers,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(len(req_pool_indices)):
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(0, prefix_lens[i])),
+                    prefix_tensors[i],
+                )
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
+                )
+                pt += extend_lens[i]
+
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
@@ -1201,10 +1227,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1218,9 +1240,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]
 
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
             list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
@@ -1244,7 +1263,49 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
-        # Copy prefix and do some basic check
+        # Allocate req slots
+        bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = [
+                (
+                    r.prefix_indices[-1:]
+                    if len(r.prefix_indices) > 0
+                    else torch.tensor([-1], device=self.device)
+                )
+                for r in self.reqs
+            ]
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu,
+                torch.cat(last_loc),
+                extend_num_tokens,
+            )
+
+        # Write allocated tokens to req_to_token_pool
+        self.write_cache_indices(
+            req_pool_indices,
+            prefix_lens,
+            seq_lens,
+            extend_lens,
+            out_cache_loc,
+            req_pool_indices_tensor,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            [r.prefix_indices for r in reqs],
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
@@ -1254,9 +1315,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
                 if isinstance(self.tree_cache, SWAChunkCache):
                     self.tree_cache.evict_swa(
                         req, pre_len, self.model_config.attention_chunk_size
@@ -1351,25 +1409,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             extend_input_logprob_token_ids = None
 
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                last_loc,
-                extend_num_tokens,
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
@@ -1402,28 +1441,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
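Because req-slot allocation now happens after the length tensors are built, `last_loc` for paged allocation can be derived directly from each request's `prefix_indices` instead of reading `req_to_token` back through `get_last_loc`. A toy illustration of the construction, not sglang code, with `device` standing in for `self.device`:

```python
import torch

device = "cpu"  # stand-in for self.device
prefix_indices = [torch.tensor([7, 8, 9]), torch.empty((0,), dtype=torch.int64)]

# Each request contributes the KV index of its last cached token, or -1
# when it has no cached prefix yet.
last_loc = [
    p[-1:] if len(p) > 0 else torch.tensor([-1], device=device)
    for p in prefix_indices
]
assert torch.cat(last_loc).tolist() == [9, -1]
```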
@@ -2024,6 +2041,7 @@ class ModelWorkerBatch:
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
@@ -2036,6 +2054,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )
 
     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index fe309e3d8..2fb355b03 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -174,7 +174,7 @@ class SchedulePolicy:
         self.waiting_queue_radix_tree.reset()
 
         for r in waiting_queue:
-            prefix_ids = r.adjust_max_prefix_ids()
+            prefix_ids = r.origin_input_ids + r.output_ids
             extra_key = r.extra_key
 
             # NOTE: the prefix_indices must always be aligned with last_node
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 6ca8d9995..54626dffd 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -60,7 +60,7 @@ class ChunkCache(BasePrefixCache):
         ]
 
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-        req.prefix_indices = kv_indices
+        req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
 
     def evict(self, num_tokens: int):
         pass
diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py
index 060535687..314e35ec9 100644
--- a/test/srt/test_forward_split_prefill.py
+++ b/test/srt/test_forward_split_prefill.py
@@ -90,7 +90,6 @@ class TestForwardSplitPrefill(CustomTestCase):
                 origin_input_ids=list(input_ids[i]),
                 sampling_params=sampling_params,
             )
-            req.prefix_indices = []
             req.fill_ids = req.origin_input_ids
             req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
             req.logprob_start_len = len(req.origin_input_ids) - 1
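For reference, the write pattern that `write_cache_indices` performs (and that the extended Triton kernel reproduces on GPU) can be summarized in pure Python, mirroring the non-Triton fallback. Toy sizes, not sglang code; `req_to_token` plays the role of the `[max_batch, max_context_len]` mapping table:

```python
import torch

req_to_token = torch.zeros((2, 8), dtype=torch.int64)
prefix_tensors = [torch.tensor([5, 6]), torch.empty((0,), dtype=torch.int64)]
prefix_lens, seq_lens, extend_lens = [2, 0], [4, 3], [2, 3]
req_pool_indices = [0, 1]
out_cache_loc = torch.tensor([10, 11, 20, 21, 22])  # 2 + 3 newly allocated slots

pt = 0
for i, idx in enumerate(req_pool_indices):
    # Cached prefix indices first, then the newly allocated extend slots.
    req_to_token[idx, : prefix_lens[i]] = prefix_tensors[i]
    req_to_token[idx, prefix_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
    pt += extend_lens[i]

assert req_to_token[0, :4].tolist() == [5, 6, 10, 11]
assert req_to_token[1, :3].tolist() == [20, 21, 22]
```

The per-request `data_ptr()` table (`prefix_pointers`) is what lets the kernel read each request's prefix tensor directly, without first concatenating the prefixes on the host.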