From 43fbb6d919d9b6c07ab256a8ab04bc4d7462df66 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sat, 10 Aug 2024 16:24:12 -0700
Subject: [PATCH] Fix `input_ids` && rename to `fill_ids` (#1021)
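
`Req.input_ids` held `origin_input_ids + output_ids`, possibly truncated
by chunked prefill, so it was easy to mistake for the untruncated prompt.
This renames it to `fill_ids` and fixes two call sites that actually want
the original input:

* `ScheduleBatch.prepare_for_decode` falls back to `origin_input_ids[-1]`
  instead of `input_ids[-1]`, whose last element can point at a mid-prompt
  token while a chunked prefill is in flight.
* `ChunkCache.cache_finished_req` and `RadixCache.cache_finished_req` key
  finished requests on `(origin_input_ids + output_ids)[:-1]`; the old
  `(input_ids + output_ids)[:-1]` could repeat decoded tokens, since
  `input_ids` already contained the ones filled so far.

A minimal sketch of the `fill_ids` invariant (illustrative toy values
only, not code from this diff):

    # fill_ids mirrors origin_input_ids + output_ids, then may be cut
    # down to the portion actually filled so far by chunked prefill
    # (see PrefillAdder.add_inflight_req below).
    origin_input_ids = [5, 6, 7, 8]  # the full prompt
    output_ids = []                  # nothing decoded yet

    fill_ids = origin_input_ids + output_ids
    prefix_len, chunk_budget = 2, 1  # cached prefix + this chunk's budget
    fill_ids = fill_ids[: prefix_len + chunk_budget]  # -> [5, 6, 7]

    # fill_ids[-1] is 7, a mid-prompt token; origin_input_ids[-1] is 8.
    assert fill_ids[-1] != origin_input_ids[-1]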
---
 python/sglang/bench_latency.py                |  6 +++---
 .../sglang/srt/managers/policy_scheduler.py   |  4 ++--
 python/sglang/srt/managers/schedule_batch.py  | 19 ++++++++++---------
 python/sglang/srt/managers/tp_worker.py       |  8 ++++----
 python/sglang/srt/mem_cache/chunk_cache.py    |  4 ++--
 python/sglang/srt/mem_cache/radix_cache.py    |  4 ++--
 .../srt/model_executor/forward_batch_info.py  | 10 +++++-----
 7 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 130f2b82f..c2b956e1d 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.input_ids = req.origin_input_ids
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
 
     return input_ids, reqs
@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
 ):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len :]
+        req.fill_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
             i, : bench_args.cut_len
         ]
@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.input_ids = req.origin_input_ids
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
 
     return reqs
diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index a05ba9c9c..e4a22242f 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -138,7 +138,7 @@ class PrefillAdder:
     def add_inflight_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
 
         self._prefill_one_req(
@@ -193,7 +193,7 @@ class PrefillAdder:
                 return False
 
             req.extend_input_len = trunc_len
-            req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
+            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
             self.can_run_list.append(req)
             self.new_inflight_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 278ed006e..a62e612b0 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -99,7 +99,7 @@ class Req:
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
-        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
 
         # Memory info
         self.req_pool_idx = None
@@ -165,12 +165,12 @@ class Req:
         return self.finished_reason is not None
 
     def init_next_round_input(self):
-        self.input_ids = self.origin_input_ids + self.output_ids
-        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
-        self.input_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.input_ids)
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
         max_prefix_len = input_len
 
         if self.sampling_params.max_new_tokens > 0:
@@ -184,7 +184,7 @@ class Req:
             # Need at least two tokens to compute normalized logprob
             max_prefix_len = min(max_prefix_len, input_len - 2)
 
-        return self.input_ids[:max_prefix_len]
+        return self.fill_ids[:max_prefix_len]
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
@@ -427,7 +427,7 @@ class ScheduleBatch:
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         bs = self.batch_size()
         reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
@@ -438,7 +438,7 @@ class ScheduleBatch:
         pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             ext_len = seq_len - pre_len
             seq_lens.append(seq_len)
 
@@ -632,7 +632,8 @@ class ScheduleBatch:
     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
-                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+                for r in self.reqs
             ]
         else:
             self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index e425a3c37..ee315bf1d 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -515,7 +515,7 @@ class ModelTpServer:
 
         self.handle_finished_requests(batch)
 
-    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+    def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
@@ -524,12 +524,12 @@ class ModelTpServer:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.input_ids[-req.extend_input_len + 1 :],
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.input_ids[0])
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
@@ -543,7 +543,7 @@ class ModelTpServer:
                         + req.extend_input_len
                         - 1
                     ],
-                    req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    req.fill_ids[-req.last_update_decode_tokens + 1 :],
                 )
             )
         )
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 100cbbaec..f8d7dd234 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -34,7 +34,7 @@ class ChunkCache(BasePrefixCache):
 
     def cache_finished_req(self, req: "Req", token_ids=None):
         if token_ids is None:
-            token_ids = (req.input_ids + req.output_ids)[:-1]
+            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]
@@ -45,7 +45,7 @@ class ChunkCache(BasePrefixCache):
 
     def cache_unfinished_req(self, req: "Req", token_ids=None):
         if token_ids is None:
-            token_ids = req.input_ids
+            token_ids = req.fill_ids
 
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 05cbb2c92..25a467304 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -92,7 +92,7 @@ class RadixCache(BasePrefixCache):
     def cache_finished_req(self, req: "Req", token_ids=None):
         """Cache request when it finishes."""
         if token_ids is None:
-            token_ids = (req.input_ids + req.output_ids)[:-1]
+            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]
@@ -116,7 +116,7 @@ class RadixCache(BasePrefixCache):
             return
 
         if token_ids is None:
-            token_ids = req.input_ids
+            token_ids = req.fill_ids
 
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 686e7ed86..dd2b59728 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -109,7 +109,7 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.input_ids))
+                        np.arange(len(req.prefix_indices), len(req.fill_ids))
                         for req in batch.reqs
                     ],
                     axis=0,
@@ -124,7 +124,7 @@ class InputMetadata:
                     [
                         np.arange(
                             len(req.prefix_indices) + position_ids_offsets_cpu[i],
-                            len(req.input_ids) + position_ids_offsets_cpu[i],
+                            len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
                     ],
@@ -141,7 +141,7 @@ class InputMetadata:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
             prefix_lens_cpu = [
-                len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
             ]
             self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
@@ -149,7 +149,7 @@ class InputMetadata:
             self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
 
     def init_total_num_tokens(self, batch: ScheduleBatch):
-        self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
+        self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
 
     @classmethod
     def from_schedule_batch(
@@ -203,7 +203,7 @@ class InputMetadata:
 
     def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
         """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
+        self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
         self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
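
A toy illustration of the `cache_finished_req` key change above
(hypothetical values; a sketch of the failure mode, not code from this
diff). Suppose a request was extended once mid-generation, so `fill_ids`
already contains an early output token:

    origin_input_ids = [5, 6, 7]
    output_ids = [9, 10]     # 10 was sampled on the last step
    fill_ids = [5, 6, 7, 9]  # origin_input_ids + output_ids as of the
                             # last extend, before 10 existed

    old_key = (fill_ids + output_ids)[:-1]          # [5, 6, 7, 9, 9]
    new_key = (origin_input_ids + output_ids)[:-1]  # [5, 6, 7, 9]

The old key repeats token 9, so its length no longer matches the KV
indices laid out for the request; the new key lists each cached position
exactly once.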