Add a test case for cached_tokens (#3145)
@@ -331,6 +331,7 @@ class Req:
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -750,13 +751,6 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -772,15 +766,20 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                extend_logprob_start_len = req.extend_input_len - 1
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.extend_logprob_start_len = extend_logprob_start_len
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
 
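Note: a minimal standalone sketch (not part of this commit) of the bookkeeping the hunks above introduce. FakeReq and the token counts below are hypothetical; only the two counters from Req are modeled. On every scheduling of a request, cached_tokens grows by the part of the reused prefix that has not been credited yet (pre_len - already_computed), and already_computed then advances to the full scheduled length, so a chunked or rescheduled request never counts the same tokens twice.

# Standalone sketch of the cached_tokens / already_computed bookkeeping above.
# FakeReq is a hypothetical stand-in for Req; only the two counters are modeled.
class FakeReq:
    def __init__(self):
        self.cached_tokens = 0     # total prompt tokens served from the KV cache
        self.already_computed = 0  # tokens of this request already accounted for

    def on_schedule(self, pre_len: int, seq_len: int):
        # Mirrors the hunk above:
        #   req.cached_tokens += pre_len - req.already_computed
        #   req.already_computed = seq_len
        self.cached_tokens += pre_len - self.already_computed
        self.already_computed = seq_len


req = FakeReq()
# First schedule: 5 prompt tokens hit the radix cache (e.g. a shared prefix).
req.on_schedule(pre_len=5, seq_len=20)
assert req.cached_tokens == 5
# Next schedule (e.g. the following chunk of a chunked prefill): the 20 tokens
# computed last round are now in the KV cache, but they were already accounted
# for, so the counter does not grow again.
req.on_schedule(pre_len=20, seq_len=30)
assert req.cached_tokens == 5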
@@ -660,24 +660,23 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
-
         if error_msg:
             self.waiting_queue.append(req)
             return
 
+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
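Note: the hunk above moves the logprob_start_len handling after the prompt-length validation and only applies the default when the request did not specify a value. A small sketch (not part of this commit; the function name and values are hypothetical) of the intended defaulting:

# Hypothetical helper mirroring the defaulting logic in the hunk above.
def resolve_logprob_start_len(requested: int, origin_input_ids: list) -> int:
    if requested == -1:
        # By default, only return the logprobs for output tokens.
        return len(origin_input_ids) - 1
    return requested


assert resolve_logprob_start_len(-1, [11, 12, 13, 14]) == 3  # default: last input token
assert resolve_logprob_start_len(1, [11, 12, 13, 14]) == 1   # user-supplied value kept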
@@ -725,12 +724,17 @@ class Scheduler:
         req.tokenizer = self.tokenizer
 
         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)
 
     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
@@ -1044,26 +1048,23 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
 
             ret = GenerationBatchResult(
@@ -1072,7 +1073,6 @@ class Scheduler:
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
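The test file referenced by the commit title is not part of this excerpt. A hypothetical sketch of how a cached_tokens check could be exercised end to end, assuming a server is already running on localhost:30000 and that /generate responses expose cached_tokens and prompt_tokens in meta_info (both assumptions, not taken from this diff):

# Hypothetical sketch, not the test added by this commit.
# Assumes a running server at localhost:30000 and that the /generate response
# carries meta_info["cached_tokens"] and meta_info["prompt_tokens"].
import requests

URL = "http://localhost:30000/generate"
PROMPT = "The capital of France is"


def generate(prompt):
    payload = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    }
    return requests.post(URL, json=payload).json()


first = generate(PROMPT)
second = generate(PROMPT)  # same prompt: its prefix should now be in the radix cache

# On the repeated request, most of the prompt is expected to be served from the
# KV cache, so cached_tokens should grow and stay bounded by the prompt length.
assert second["meta_info"]["cached_tokens"] >= first["meta_info"]["cached_tokens"]
assert second["meta_info"]["cached_tokens"] <= second["meta_info"]["prompt_tokens"]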