Returning a per request metric for number of cached_tokens read (#1599)
This commit is contained in:
@@ -196,6 +196,9 @@ class Req:
|
|||||||
# this does not include the jump forward tokens.
|
# this does not include the jump forward tokens.
|
||||||
self.completion_tokens_wo_jump_forward = 0
|
self.completion_tokens_wo_jump_forward = 0
|
||||||
|
|
||||||
|
# The number of tokens that were already cached in the KV store
|
||||||
|
self.cached_tokens = 0
|
||||||
|
|
||||||
# For vision inputs
|
# For vision inputs
|
||||||
self.image_inputs: Optional[ImageInputs] = None
|
self.image_inputs: Optional[ImageInputs] = None
|
||||||
|
|
||||||
@@ -499,6 +502,13 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
|
already_computed = (
|
||||||
|
req.extend_logprob_start_len + 1 + req.cached_tokens
|
||||||
|
if req.extend_logprob_start_len > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
req.cached_tokens += len(req.prefix_indices) - already_computed
|
||||||
|
|
||||||
req.req_pool_idx = req_pool_indices[i]
|
req.req_pool_idx = req_pool_indices[i]
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
seq_lens.append(seq_len)
|
seq_lens.append(seq_len)
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ class SchedulePolicy:
|
|||||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||||
rid=r.rid, key=r.adjust_max_prefix_ids()
|
rid=r.rid, key=r.adjust_max_prefix_ids()
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix_computed = True
|
prefix_computed = True
|
||||||
|
|
||||||
if self.policy == "lpm":
|
if self.policy == "lpm":
|
||||||
|
|||||||
@@ -978,6 +978,7 @@ class Scheduler:
|
|||||||
"prompt_tokens": len(req.origin_input_ids),
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
"completion_tokens": len(req.output_ids),
|
"completion_tokens": len(req.output_ids),
|
||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
|
"cached_tokens": req.cached_tokens,
|
||||||
"finish_reason": (
|
"finish_reason": (
|
||||||
req.finished_reason.to_json()
|
req.finished_reason.to_json()
|
||||||
if req.finished_reason is not None
|
if req.finished_reason is not None
|
||||||
|
|||||||
@@ -302,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
|||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
if end_point == "/v1/chat/completions":
|
if end_point == "/v1/chat/completions":
|
||||||
responses = v1_chat_generate_response(request, ret, to_file=True)
|
responses = v1_chat_generate_response(
|
||||||
|
request,
|
||||||
|
ret,
|
||||||
|
to_file=True,
|
||||||
|
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
responses = v1_generate_response(
|
responses = v1_generate_response(
|
||||||
request, ret, tokenizer_manager, to_file=True
|
request, ret, tokenizer_manager, to_file=True
|
||||||
@@ -970,7 +975,7 @@ def v1_chat_generate_request(
|
|||||||
return adapted_request, all_requests
|
return adapted_request, all_requests
|
||||||
|
|
||||||
|
|
||||||
def v1_chat_generate_response(request, ret, to_file=False):
|
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||||
choices = []
|
choices = []
|
||||||
|
|
||||||
for idx, ret_item in enumerate(ret):
|
for idx, ret_item in enumerate(ret):
|
||||||
@@ -1067,6 +1072,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
|||||||
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
|
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
|
||||||
)
|
)
|
||||||
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
|
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
|
||||||
|
cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=ret[0]["meta_info"]["id"],
|
id=ret[0]["meta_info"]["id"],
|
||||||
model=request.model,
|
model=request.model,
|
||||||
@@ -1075,6 +1081,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
|||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
prompt_tokens_details=(
|
||||||
|
{"cached_tokens": cached_tokens} if cache_report else None
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
@@ -1240,7 +1249,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
|
|
||||||
response = v1_chat_generate_response(request, ret)
|
response = v1_chat_generate_response(
|
||||||
|
request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
|
|||||||
prompt_tokens: int = 0
|
prompt_tokens: int = 0
|
||||||
total_tokens: int = 0
|
total_tokens: int = 0
|
||||||
completion_tokens: Optional[int] = 0
|
completion_tokens: Optional[int] = 0
|
||||||
|
# only used to return cached tokens when --enable-cache-report is set
|
||||||
|
prompt_tokens_details: Optional[Dict[str, int]] = None
|
||||||
|
|
||||||
|
|
||||||
class StreamOptions(BaseModel):
|
class StreamOptions(BaseModel):
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class ServerArgs:
|
|||||||
# Other
|
# Other
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
file_storage_pth: str = "SGLang_storage"
|
file_storage_pth: str = "SGLang_storage"
|
||||||
|
enable_cache_report: bool = False
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
@@ -410,6 +411,11 @@ class ServerArgs:
|
|||||||
default=ServerArgs.file_storage_pth,
|
default=ServerArgs.file_storage_pth,
|
||||||
help="The path of the file storage in backend.",
|
help="The path of the file storage in backend.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-cache-report",
|
||||||
|
action="store_true",
|
||||||
|
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
||||||
|
)
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
211
test/srt/test_cache_report.py
Normal file
211
test/srt/test_cache_report.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheReport(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.min_cached = 5
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=300,
|
||||||
|
other_args=[
|
||||||
|
"--chunked-prefill-size=40",
|
||||||
|
"--enable-cache-report",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||||
|
cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||||
|
|
||||||
|
usage = cls.run_openai(cls, "1").usage
|
||||||
|
# we can assume that our request is of size 1, plus the total template size
|
||||||
|
# ideally we would like to know the begin size / end size of the template to be more precise
|
||||||
|
total_template_size = usage.prompt_tokens - 1
|
||||||
|
print(f"template size: {total_template_size}")
|
||||||
|
usage2 = cls.run_openai(cls, "2").usage
|
||||||
|
assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
|
||||||
|
cls.min_cached = max(
|
||||||
|
usage2.prompt_tokens_details.cached_tokens,
|
||||||
|
total_template_size - usage2.prompt_tokens_details.cached_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid)
|
||||||
|
|
||||||
|
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
# we use an uncommon start to minimise the chance that the cache is hit by chance
|
||||||
|
json={
|
||||||
|
"text": "_ The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
|
"max_new_tokens": 128,
|
||||||
|
"n": n,
|
||||||
|
"stop_token_ids": [119690],
|
||||||
|
},
|
||||||
|
"stream": False,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def run_openai(self, message):
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
# {"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{"role": "user", "content": message},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def run_openai_async(self, message):
|
||||||
|
response = await self.aclient.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": message},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def cache_report_openai(self, message):
|
||||||
|
response = self.run_openai(message)
|
||||||
|
print(
|
||||||
|
f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
|
||||||
|
)
|
||||||
|
first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||||
|
# assert int(response.usage.cached_tokens) == 0
|
||||||
|
assert first_cached_tokens < self.min_cached
|
||||||
|
response = self.run_openai(message)
|
||||||
|
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||||
|
print(f"openai second request cached_tokens: {cached_tokens}")
|
||||||
|
assert cached_tokens > 0
|
||||||
|
assert cached_tokens == int(response.usage.prompt_tokens) - 1
|
||||||
|
return first_cached_tokens
|
||||||
|
|
||||||
|
async def cache_report_openai_async(self, message):
|
||||||
|
response = await self.run_openai_async(message)
|
||||||
|
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||||
|
prompt_tokens = int(response.usage.prompt_tokens)
|
||||||
|
return cached_tokens, prompt_tokens
|
||||||
|
|
||||||
|
def test_generate(self):
|
||||||
|
print("=" * 100)
|
||||||
|
response = self.run_decode()
|
||||||
|
# print(response.json())
|
||||||
|
cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
|
||||||
|
print(f"sglang first request cached_tokens: {cached_tokens}")
|
||||||
|
print(
|
||||||
|
f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
|
||||||
|
)
|
||||||
|
# can't be guaranteed to be 0: it depends on the initialisation request and/or whether a template is used with the model
|
||||||
|
assert cached_tokens < self.min_cached
|
||||||
|
response = self.run_decode()
|
||||||
|
cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
|
||||||
|
print(f"sglang second request cached_tokens: {cached_tokens}")
|
||||||
|
print(
|
||||||
|
f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
|
||||||
|
)
|
||||||
|
assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1
|
||||||
|
|
||||||
|
def test_cache_split_prefill_openai(self):
|
||||||
|
print("=" * 100)
|
||||||
|
self.cache_report_openai(
|
||||||
|
"€ This is a very long and unique text that should not be already cached, the twist is"
|
||||||
|
" that it should be longer than the chunked-prefill-size, so it should be split among"
|
||||||
|
" several prefill requests. Still, it shouldn't be cached"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cache_report_openai(self):
|
||||||
|
print("=" * 100)
|
||||||
|
# warm up the cache, for the template
|
||||||
|
self.run_openai("Introduce the capital of France.")
|
||||||
|
|
||||||
|
first_cached_tokens_1 = self.run_openai(
|
||||||
|
"How many sparrow do you need to lift a coconut?"
|
||||||
|
).usage.prompt_tokens_details.cached_tokens
|
||||||
|
|
||||||
|
usage_2 = self.run_openai("* sing something about cats").usage
|
||||||
|
first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
|
||||||
|
# first request may not have 0 cached tokens, but if they only have the template in common they
|
||||||
|
# should be the same once the cache is warmed up
|
||||||
|
assert first_cached_tokens_1 == first_cached_tokens_2
|
||||||
|
|
||||||
|
resp = self.run_openai("* sing something about cats and dogs")
|
||||||
|
print(resp.usage)
|
||||||
|
|
||||||
|
resp = self.run_openai("* sing something about cats, please")
|
||||||
|
print(resp.usage)
|
||||||
|
assert (
|
||||||
|
resp.usage.prompt_tokens_details.cached_tokens
|
||||||
|
>= usage_2.prompt_tokens - self.min_cached
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cache_report_openai_async(self):
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
async def run_test():
|
||||||
|
task0 = asyncio.create_task(
|
||||||
|
self.cache_report_openai_async(
|
||||||
|
"first request, to start the inference and let the next two request be started in the same batch"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.05) # to force the first request to be started first
|
||||||
|
task1 = asyncio.create_task(
|
||||||
|
self.cache_report_openai_async(
|
||||||
|
"> can the same batch parallel request use the cache?"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
task2 = asyncio.create_task(
|
||||||
|
self.cache_report_openai_async(
|
||||||
|
"> can the same batch parallel request use the cache?"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result0, result1, result2 = await asyncio.gather(task0, task1, task2)
|
||||||
|
|
||||||
|
cached_tokens0, prompt_tokens0 = result0
|
||||||
|
cached_tokens1, prompt_tokens1 = result1
|
||||||
|
cached_tokens2, prompt_tokens2 = result2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that no requests used the cache (because the first is alone, and the next two are in the same batch)
|
||||||
|
# If a new optimisation were added that avoids starting requests with the same prefix at the same time
|
||||||
|
# to maximise the cache hit, this would not be true
|
||||||
|
assert cached_tokens1 == cached_tokens2 == cached_tokens0
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user