Returning a per-request metric for the number of cached_tokens read (#1599)

This commit is contained in:
havetc
2024-10-16 20:49:22 +02:00
committed by GitHub
parent dbec2f1847
commit ecb8bad276
7 changed files with 245 additions and 3 deletions

View File

@@ -196,6 +196,9 @@ class Req:
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
# The number of tokens that were already cached in the KV store
self.cached_tokens = 0
# For vision inputs
self.image_inputs: Optional[ImageInputs] = None
@@ -499,6 +502,13 @@ class ScheduleBatch:
pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
if req.extend_logprob_start_len > 0
else 0
)
req.cached_tokens += len(req.prefix_indices) - already_computed
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
seq_lens.append(seq_len)