Returning a per request metric for number of cached_tokens read (#1599)

This commit is contained in:
havetc
2024-10-16 20:49:22 +02:00
committed by GitHub
parent dbec2f1847
commit ecb8bad276
7 changed files with 245 additions and 3 deletions

View File

@@ -196,6 +196,9 @@ class Req:
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
# The number of cached tokens, that were already cached in the KV store
self.cached_tokens = 0
# For vision inputs
self.image_inputs: Optional[ImageInputs] = None
@@ -499,6 +502,13 @@ class ScheduleBatch:
pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
if req.extend_logprob_start_len > 0
else 0
)
req.cached_tokens += len(req.prefix_indices) - already_computed
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
seq_lens.append(seq_len)

View File

@@ -51,6 +51,7 @@ class SchedulePolicy:
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = True
if self.policy == "lpm":

View File

@@ -978,6 +978,7 @@ class Scheduler:
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"cached_tokens": req.cached_tokens,
"finish_reason": (
req.finished_reason.to_json()
if req.finished_reason is not None

View File

@@ -302,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
if not isinstance(ret, list):
ret = [ret]
if end_point == "/v1/chat/completions":
responses = v1_chat_generate_response(request, ret, to_file=True)
responses = v1_chat_generate_response(
request,
ret,
to_file=True,
cache_report=tokenizer_manager.server_args.enable_cache_report,
)
else:
responses = v1_generate_response(
request, ret, tokenizer_manager, to_file=True
@@ -970,7 +975,7 @@ def v1_chat_generate_request(
return adapted_request, all_requests
def v1_chat_generate_response(request, ret, to_file=False):
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
choices = []
for idx, ret_item in enumerate(ret):
@@ -1067,6 +1072,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
@@ -1075,6 +1081,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=(
{"cached_tokens": cached_tokens} if cache_report else None
),
),
)
return response
@@ -1240,7 +1249,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
response = v1_chat_generate_response(request, ret)
response = v1_chat_generate_response(
request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
)
return response

View File

@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details: Optional[Dict[str, int]] = None
class StreamOptions(BaseModel):

View File

@@ -73,6 +73,7 @@ class ServerArgs:
# Other
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
# Data parallelism
dp_size: int = 1
@@ -410,6 +411,11 @@ class ServerArgs:
default=ServerArgs.file_storage_pth,
help="The path of the file storage in backend.",
)
parser.add_argument(
"--enable-cache-report",
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
# Data parallelism
parser.add_argument(