Fix input_ids && rename to fill_ids (#1021)
This commit is contained in:
@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
|
|||||||
req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
|
req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
|
||||||
req.prefix_indices = []
|
req.prefix_indices = []
|
||||||
req.sampling_params = sampling_params
|
req.sampling_params = sampling_params
|
||||||
req.input_ids = req.origin_input_ids
|
req.fill_ids = req.origin_input_ids
|
||||||
reqs.append(req)
|
reqs.append(req)
|
||||||
|
|
||||||
return input_ids, reqs
|
return input_ids, reqs
|
||||||
@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
|
|||||||
):
|
):
|
||||||
for i in range(len(reqs)):
|
for i in range(len(reqs)):
|
||||||
req = reqs[i]
|
req = reqs[i]
|
||||||
req.input_ids += input_ids[i][bench_args.cut_len :]
|
req.fill_ids += input_ids[i][bench_args.cut_len :]
|
||||||
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
||||||
i, : bench_args.cut_len
|
i, : bench_args.cut_len
|
||||||
]
|
]
|
||||||
@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
|||||||
req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
|
req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
|
||||||
req.prefix_indices = []
|
req.prefix_indices = []
|
||||||
req.sampling_params = sampling_params
|
req.sampling_params = sampling_params
|
||||||
req.input_ids = req.origin_input_ids
|
req.fill_ids = req.origin_input_ids
|
||||||
reqs.append(req)
|
reqs.append(req)
|
||||||
|
|
||||||
return reqs
|
return reqs
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class PrefillAdder:
|
|||||||
def add_inflight_req(self, req: Req):
|
def add_inflight_req(self, req: Req):
|
||||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
|
|
||||||
self._prefill_one_req(
|
self._prefill_one_req(
|
||||||
@@ -193,7 +193,7 @@ class PrefillAdder:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_inflight_req = req
|
self.new_inflight_req = req
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class Req:
|
|||||||
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
|
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
|
||||||
self.origin_input_ids = origin_input_ids
|
self.origin_input_ids = origin_input_ids
|
||||||
self.output_ids = [] # Each decode stage's output ids
|
self.output_ids = [] # Each decode stage's output ids
|
||||||
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||||
|
|
||||||
# Memory info
|
# Memory info
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
@@ -165,12 +165,12 @@ class Req:
|
|||||||
return self.finished_reason is not None
|
return self.finished_reason is not None
|
||||||
|
|
||||||
def init_next_round_input(self):
|
def init_next_round_input(self):
|
||||||
self.input_ids = self.origin_input_ids + self.output_ids
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||||
self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||||
|
|
||||||
def adjust_max_prefix_ids(self):
|
def adjust_max_prefix_ids(self):
|
||||||
self.input_ids = self.origin_input_ids + self.output_ids
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||||
input_len = len(self.input_ids)
|
input_len = len(self.fill_ids)
|
||||||
max_prefix_len = input_len
|
max_prefix_len = input_len
|
||||||
|
|
||||||
if self.sampling_params.max_new_tokens > 0:
|
if self.sampling_params.max_new_tokens > 0:
|
||||||
@@ -184,7 +184,7 @@ class Req:
|
|||||||
# Need at least two tokens to compute normalized logprob
|
# Need at least two tokens to compute normalized logprob
|
||||||
max_prefix_len = min(max_prefix_len, input_len - 2)
|
max_prefix_len = min(max_prefix_len, input_len - 2)
|
||||||
|
|
||||||
return self.input_ids[:max_prefix_len]
|
return self.fill_ids[:max_prefix_len]
|
||||||
|
|
||||||
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
||||||
def init_incremental_detokenize(self):
|
def init_incremental_detokenize(self):
|
||||||
@@ -427,7 +427,7 @@ class ScheduleBatch:
|
|||||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||||
bs = self.batch_size()
|
bs = self.batch_size()
|
||||||
reqs = self.reqs
|
reqs = self.reqs
|
||||||
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||||
seq_lens = []
|
seq_lens = []
|
||||||
|
|
||||||
@@ -438,7 +438,7 @@ class ScheduleBatch:
|
|||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
ext_len = seq_len - pre_len
|
ext_len = seq_len - pre_len
|
||||||
seq_lens.append(seq_len)
|
seq_lens.append(seq_len)
|
||||||
|
|
||||||
@@ -632,7 +632,8 @@ class ScheduleBatch:
|
|||||||
def prepare_for_decode(self, input_ids=None):
|
def prepare_for_decode(self, input_ids=None):
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
input_ids = [
|
input_ids = [
|
||||||
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
|
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
||||||
|
for r in self.reqs
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
|
self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
|
||||||
|
|||||||
@@ -515,7 +515,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
|
def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
|
||||||
if req.normalized_prompt_logprob is None:
|
if req.normalized_prompt_logprob is None:
|
||||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||||
|
|
||||||
@@ -524,12 +524,12 @@ class ModelTpServer:
|
|||||||
req.input_token_logprobs = list(
|
req.input_token_logprobs = list(
|
||||||
zip(
|
zip(
|
||||||
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
|
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||||
req.input_ids[-req.extend_input_len + 1 :],
|
req.fill_ids[-req.extend_input_len + 1 :],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if req.logprob_start_len == 0:
|
if req.logprob_start_len == 0:
|
||||||
req.input_token_logprobs = [
|
req.input_token_logprobs = [
|
||||||
(None, req.input_ids[0])
|
(None, req.fill_ids[0])
|
||||||
] + req.input_token_logprobs
|
] + req.input_token_logprobs
|
||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
if req.last_update_decode_tokens != 0:
|
||||||
@@ -543,7 +543,7 @@ class ModelTpServer:
|
|||||||
+ req.extend_input_len
|
+ req.extend_input_len
|
||||||
- 1
|
- 1
|
||||||
],
|
],
|
||||||
req.input_ids[-req.last_update_decode_tokens + 1 :],
|
req.fill_ids[-req.last_update_decode_tokens + 1 :],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
|
|
||||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||||
if token_ids is None:
|
if token_ids is None:
|
||||||
token_ids = (req.input_ids + req.output_ids)[:-1]
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||||
|
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(token_ids)
|
req.req_pool_idx, : len(token_ids)
|
||||||
@@ -45,7 +45,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
|
|
||||||
def cache_unfinished_req(self, req: "Req", token_ids=None):
|
def cache_unfinished_req(self, req: "Req", token_ids=None):
|
||||||
if token_ids is None:
|
if token_ids is None:
|
||||||
token_ids = req.input_ids
|
token_ids = req.fill_ids
|
||||||
|
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(token_ids)
|
req.req_pool_idx, : len(token_ids)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||||
"""Cache request when it finishes."""
|
"""Cache request when it finishes."""
|
||||||
if token_ids is None:
|
if token_ids is None:
|
||||||
token_ids = (req.input_ids + req.output_ids)[:-1]
|
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(token_ids)
|
req.req_pool_idx, : len(token_ids)
|
||||||
]
|
]
|
||||||
@@ -116,7 +116,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if token_ids is None:
|
if token_ids is None:
|
||||||
token_ids = req.input_ids
|
token_ids = req.fill_ids
|
||||||
|
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(token_ids)
|
req.req_pool_idx, : len(token_ids)
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class InputMetadata:
|
|||||||
self.positions = torch.tensor(
|
self.positions = torch.tensor(
|
||||||
np.concatenate(
|
np.concatenate(
|
||||||
[
|
[
|
||||||
np.arange(len(req.prefix_indices), len(req.input_ids))
|
np.arange(len(req.prefix_indices), len(req.fill_ids))
|
||||||
for req in batch.reqs
|
for req in batch.reqs
|
||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
@@ -124,7 +124,7 @@ class InputMetadata:
|
|||||||
[
|
[
|
||||||
np.arange(
|
np.arange(
|
||||||
len(req.prefix_indices) + position_ids_offsets_cpu[i],
|
len(req.prefix_indices) + position_ids_offsets_cpu[i],
|
||||||
len(req.input_ids) + position_ids_offsets_cpu[i],
|
len(req.fill_ids) + position_ids_offsets_cpu[i],
|
||||||
)
|
)
|
||||||
for i, req in enumerate(batch.reqs)
|
for i, req in enumerate(batch.reqs)
|
||||||
],
|
],
|
||||||
@@ -141,7 +141,7 @@ class InputMetadata:
|
|||||||
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
||||||
else:
|
else:
|
||||||
prefix_lens_cpu = [
|
prefix_lens_cpu = [
|
||||||
len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
|
len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
|
||||||
]
|
]
|
||||||
self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
|
self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
|
||||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||||
@@ -149,7 +149,7 @@ class InputMetadata:
|
|||||||
self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
|
self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
|
||||||
|
|
||||||
def init_total_num_tokens(self, batch: ScheduleBatch):
|
def init_total_num_tokens(self, batch: ScheduleBatch):
|
||||||
self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
|
self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def from_schedule_batch(
|
||||||
@@ -203,7 +203,7 @@ class InputMetadata:
|
|||||||
|
|
||||||
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
|
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
|
self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
|
||||||
self.triton_prefix_lens = prefix_lens
|
self.triton_prefix_lens = prefix_lens
|
||||||
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
||||||
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user