Add a test case for cached_tokens (#3145)

Lianmin Zheng
2025-01-26 01:39:28 -08:00
committed by GitHub
parent f8b28e461a
commit d1a0863251
6 changed files with 74 additions and 63 deletions
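
The test itself lives in one of the changed files not shown below. Purely as an illustrative sketch (the server address, the sampling parameters, and the meta_info["cached_tokens"] field are assumptions for the example, not taken from this diff), a cached-tokens check against a running server could look like:

# Hypothetical sketch, not the test shipped in this commit: issue the same
# prompt twice and expect the second request to report a cached prefix.
import requests

BASE_URL = "http://localhost:30000"  # assumed address of a locally running server

def generate(prompt: str) -> dict:
    resp = requests.post(
        f"{BASE_URL}/generate",
        json={
            "text": prompt,
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
    )
    resp.raise_for_status()
    return resp.json()

prompt = "The capital of France is"
first = generate(prompt)   # cold request: nothing in the prefix cache yet
second = generate(prompt)  # the shared prefix should now be a cache hit
# "cached_tokens" as a meta_info key is assumed here for illustration.
print(first["meta_info"]["cached_tokens"], second["meta_info"]["cached_tokens"])
assert second["meta_info"]["cached_tokens"] > 0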


@@ -331,6 +331,7 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -750,13 +751,6 @@ class ScheduleBatch:
         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -772,15 +766,20 @@
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                extend_logprob_start_len = req.extend_input_len - 1
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.extend_logprob_start_len = extend_logprob_start_len
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
+
             req.is_retracted = False
             pre_lens.append(pre_len)
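
Read together with the new already_computed field on Req above, the replacement block keeps the accounting simple: on each prefill (extend) iteration, pre_len tokens are reused from the prefix cache, already_computed of them were produced by this same request on an earlier chunk, and after the iteration everything up to seq_len has been computed. A standalone sketch of that bookkeeping, using hypothetical names rather than the real classes:

from dataclasses import dataclass

@dataclass
class FakeReq:  # stand-in for Req; only the two counters matter here
    cached_tokens: int = 0
    already_computed: int = 0

def account_iteration(req: FakeReq, pre_len: int, seq_len: int) -> None:
    # pre_len: tokens reused from the KV cache in this prefill iteration
    # seq_len: total tokens of this request computed after the iteration
    req.cached_tokens += pre_len - req.already_computed
    req.already_computed = seq_len

# A prompt prefilled in two chunks: chunk 2 reuses chunk 1's KV entries, but
# this request computed them itself, so they are not counted as cache hits.
req = FakeReq()
account_iteration(req, pre_len=0, seq_len=512)
account_iteration(req, pre_len=512, seq_len=1024)
assert req.cached_tokens == 0

# A new request that genuinely reuses a 512-token prefix cached earlier.
req2 = FakeReq()
account_iteration(req2, pre_len=512, seq_len=1024)
assert req2.cached_tokens == 512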


@@ -660,24 +660,23 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
         if error_msg:
             self.waiting_queue.append(req)
             return
 
+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
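
The "Copy more attributes" block is moved rather than removed: the default logprob_start_len (return logprobs only for the generated tokens) is now derived from origin_input_ids after the length check, presumably so that it reflects a prompt that may have been auto-truncated. A small hypothetical illustration of why that ordering matters (the limit and helper below are invented for the example):

# Hypothetical illustration: deriving the "output tokens only" default before
# truncation can leave an index that points past the end of the kept prompt.
def default_logprob_start_len(input_ids, requested: int) -> int:
    # requested == -1 means "only return logprobs for output tokens"
    return len(input_ids) - 1 if requested == -1 else requested

max_req_input_len = 8192                   # assumed limit for the example
input_ids = list(range(10_000))            # an over-long prompt
truncated = input_ids[:max_req_input_len]  # what auto-truncation might keep

stale = default_logprob_start_len(input_ids, -1)  # 9999, past the truncated end
fresh = default_logprob_start_len(truncated, -1)  # 8191, last kept input token
assert stale >= len(truncated)
assert fresh == len(truncated) - 1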
@@ -725,12 +724,17 @@
         req.tokenizer = self.tokenizer
 
         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
+
+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
 
         self.waiting_queue.append(req)
 
     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
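
The embedding handler previously discarded the result of validate_input_length; it now mirrors the generate path by capturing the error message, queueing the request as-is, and skipping the rest of the setup when validation fails. The call shape used in both handlers is consistent with a validator along the lines of the sketch below, which is an assumption about its behavior rather than the real implementation:

from typing import Optional

def validate_input_length_sketch(
    req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Assumed contract: return None when the request fits (possibly after
    truncating it in place), or an error message string that the caller should
    handle before continuing with normal scheduling."""
    if len(req.origin_input_ids) <= max_req_input_len:
        return None
    if allow_auto_truncate:
        req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
        return None
    return (
        f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
        f"the maximum allowed length ({max_req_input_len} tokens)."
    )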
@@ -1044,26 +1048,23 @@
         self.forward_ct += 1
 
         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
 
             batch.output_ids = next_token_ids
 
             ret = GenerationBatchResult(
@@ -1072,7 +1073,6 @@
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
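
For reference, the two speculative-decoding counters touched in the run_batch hunk above add the accepted draft tokens plus one token per request per forward pass; reading their ratio as an average acceptance length is an inference from that arithmetic, not something stated in this diff. A worked example with invented numbers:

# Worked example of the counter updates in the speculative-decoding branch:
#   spec_num_total_accepted_tokens += num_accepted_tokens + batch.batch_size()
#   spec_num_total_forward_ct      += batch.batch_size()
# With a batch of 4 requests whose verification step accepted 9 draft tokens,
# each request also emits one token of its own, so 13 tokens come out of 4
# per-request forward passes.
batch_size = 4
num_accepted_tokens = 9

spec_num_total_accepted_tokens = num_accepted_tokens + batch_size  # 13
spec_num_total_forward_ct = batch_size                             # 4

avg_accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(avg_accept_length)  # 3.25 tokens per request per forward pass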