From ecb8bad276ea13e243a36cc23adca8207fac4657 Mon Sep 17 00:00:00 2001
From: havetc
Date: Wed, 16 Oct 2024 20:49:22 +0200
Subject: [PATCH] Returning a per-request metric for the number of
 cached_tokens read (#1599)

---
 python/sglang/srt/managers/schedule_batch.py  |  10 +
 python/sglang/srt/managers/schedule_policy.py |   1 +
 python/sglang/srt/managers/scheduler.py       |   1 +
 python/sglang/srt/openai_api/adapter.py       |  17 +-
 python/sglang/srt/openai_api/protocol.py      |   2 +
 python/sglang/srt/server_args.py              |   6 +
 test/srt/test_cache_report.py                 | 211 ++++++++++++++++++
 7 files changed, 245 insertions(+), 3 deletions(-)
 create mode 100644 test/srt/test_cache_report.py

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8684fa5ce..0eeb3359e 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -196,6 +196,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of tokens that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -499,6 +502,13 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 74a3a621c..45c9be37a 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -51,6 +51,7 @@ class SchedulePolicy:
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                     rid=r.rid, key=r.adjust_max_prefix_ids()
                 )
+
             prefix_computed = True
 
         if self.policy == "lpm":
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 18ca20409..9f6989c25 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -978,6 +978,7 @@ class Scheduler:
                     "prompt_tokens": len(req.origin_input_ids),
                     "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                    "cached_tokens": req.cached_tokens,
                     "finish_reason": (
                         req.finished_reason.to_json()
                         if req.finished_reason is not None
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 5c8990c69..4e70546df 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -302,7 +302,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         if not isinstance(ret, list):
             ret = [ret]
         if end_point == "/v1/chat/completions":
-            responses = v1_chat_generate_response(request, ret, to_file=True)
+            responses = v1_chat_generate_response(
+                request,
+                ret,
+                to_file=True,
+                cache_report=tokenizer_manager.server_args.enable_cache_report,
+            )
         else:
             responses = v1_generate_response(
                 request, ret, tokenizer_manager, to_file=True
@@ -970,7 +975,7 @@ def v1_chat_generate_request(
 
     return adapted_request, all_requests
 
 
-def v1_chat_generate_response(request, ret, to_file=False):
+def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
     choices = []
     for idx, ret_item in enumerate(ret):
@@ -1067,6 +1072,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
         ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
     )
     completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+    cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
     response = ChatCompletionResponse(
         id=ret[0]["meta_info"]["id"],
         model=request.model,
@@ -1075,6 +1081,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=(
+                {"cached_tokens": cached_tokens} if cache_report else None
+            ),
         ),
     )
     return response
@@ -1240,7 +1249,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_chat_generate_response(request, ret)
+    response = v1_chat_generate_response(
+        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+    )
 
     return response
 
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 1da27af28..349944f70 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -76,6 +76,8 @@ class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
    completion_tokens: Optional[int] = 0
+    # Only used to return cached tokens when --enable-cache-report is set
+    prompt_tokens_details: Optional[Dict[str, int]] = None
 
 
 class StreamOptions(BaseModel):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index eafdede82..10f63e697 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -73,6 +73,7 @@ class ServerArgs:
     # Other
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
+    enable_cache_report: bool = False
 
     # Data parallelism
     dp_size: int = 1
@@ -410,6 +411,11 @@ class ServerArgs:
             default=ServerArgs.file_storage_pth,
             help="The path of the file storage in backend.",
         )
+        parser.add_argument(
+            "--enable-cache-report",
+            action="store_true",
+            help="Return the number of cached tokens in usage.prompt_tokens_details for each OpenAI request.",
+        )
 
         # Data parallelism
         parser.add_argument(
diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py
new file mode 100644
index 000000000..1d8e9a4a0
--- /dev/null
+++ b/test/srt/test_cache_report.py
@@ -0,0 +1,211 @@
+import asyncio
+import json
+import unittest
+
+import openai
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestCacheReport(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.min_cached = 5
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=[
+                "--chunked-prefill-size=40",
+                "--enable-cache-report",
+            ],
+        )
+        cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
+        cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
+
+        usage = cls.run_openai(cls, "1").usage
+        # We can assume that our request is of size 1 plus the total template size;
+        # ideally we would know the begin/end sizes of the template to be more precise.
+        total_template_size = usage.prompt_tokens - 1
+        print(f"template size: {total_template_size}")
+        usage2 = cls.run_openai(cls, "2").usage
+        assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
+        cls.min_cached = max(
+            usage2.prompt_tokens_details.cached_tokens,
+            total_template_size - usage2.prompt_tokens_details.cached_tokens,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
+        response = requests.post(
+            self.base_url + "/generate",
+            # Use an uncommon start token to minimise the chance of an accidental cache hit.
+            json={
+                "text": "_ The capital of France is",
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": 128,
+                    "n": n,
+                    "stop_token_ids": [119690],
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "logprob_start_len": 0,
+            },
+        )
+        return response
+
+    def run_openai(self, message):
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=[
+                # {"role": "system", "content": "You are a helpful AI assistant"},
+                {"role": "user", "content": message},
+            ],
+            temperature=0,
+            max_tokens=100,
+        )
+        return response
+
+    async def run_openai_async(self, message):
+        response = await self.aclient.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "user", "content": message},
+            ],
+            temperature=0,
+            max_tokens=100,
+        )
+        return response
+
+    def cache_report_openai(self, message):
+        response = self.run_openai(message)
+        print(
+            f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
+        )
+        first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
+        # assert int(response.usage.cached_tokens) == 0
+        assert first_cached_tokens < self.min_cached
+        response = self.run_openai(message)
+        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
+        print(f"openai second request cached_tokens: {cached_tokens}")
+        assert cached_tokens > 0
+        assert cached_tokens == int(response.usage.prompt_tokens) - 1
+        return first_cached_tokens
+
+    async def cache_report_openai_async(self, message):
+        response = await self.run_openai_async(message)
+        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
+        prompt_tokens = int(response.usage.prompt_tokens)
+        return cached_tokens, prompt_tokens
+
+    def test_generate(self):
+        print("=" * 100)
+        response = self.run_decode()
+        # print(response.json())
+        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
+        print(f"sglang first request cached_tokens: {cached_tokens}")
+        print(
+            f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
+        )
+        # Cannot be assumed to be 0: it depends on the initialisation request / whether a template is used with the model.
+        assert cached_tokens < self.min_cached
+        response = self.run_decode()
+        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
+        print(f"sglang second request cached_tokens: {cached_tokens}")
+        print(
+            f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
+        )
+        assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1
+
+    def test_cache_split_prefill_openai(self):
+        print("=" * 100)
+        self.cache_report_openai(
+            "€ This is a very long and unique text that should not be already cached, the twist is"
+            " that it should be longer than the chunked-prefill-size, so it should be split among"
+            " several prefill requests. Still, it shouldn't be cached"
+        )
+
+    def test_cache_report_openai(self):
+        print("=" * 100)
+        # warm up the cache for the template
+        self.run_openai("Introduce the capital of France.")
+
+        first_cached_tokens_1 = self.run_openai(
+            "How many sparrow do you need to lift a coconut?"
+        ).usage.prompt_tokens_details.cached_tokens
+
+        usage_2 = self.run_openai("* sing something about cats").usage
+        first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
+        # First requests may not have 0 cached tokens, but if they share only the
+        # template, their cached counts should match once the cache is warmed up.
+        assert first_cached_tokens_1 == first_cached_tokens_2
+
+        resp = self.run_openai("* sing something about cats and dogs")
+        print(resp.usage)
+
+        resp = self.run_openai("* sing something about cats, please")
+        print(resp.usage)
+        assert (
+            resp.usage.prompt_tokens_details.cached_tokens
+            >= usage_2.prompt_tokens - self.min_cached
+        )
+
+    def test_cache_report_openai_async(self):
+        print("=" * 100)
+
+        async def run_test():
+            task0 = asyncio.create_task(
+                self.cache_report_openai_async(
+                    "first request, to start the inference and let the next two request be started in the same batch"
+                )
+            )
+            await asyncio.sleep(0.05)  # to force the first request to be started first
+            task1 = asyncio.create_task(
+                self.cache_report_openai_async(
+                    "> can the same batch parallel request use the cache?"
+                )
+            )
+            task2 = asyncio.create_task(
+                self.cache_report_openai_async(
+                    "> can the same batch parallel request use the cache?"
+                )
+            )
+            result0, result1, result2 = await asyncio.gather(task0, task1, task2)
+
+            cached_tokens0, prompt_tokens0 = result0
+            cached_tokens1, prompt_tokens1 = result1
+            cached_tokens2, prompt_tokens2 = result2
+
+            print(
+                f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
+            )
+            print(
+                f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
+            )
+            print(
+                f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
+            )
+
+            # Assert that none of the requests used the cache (the first ran alone, and the next two were in the same batch).
+            # If a new optimisation were added that delays requests sharing a prefix
+            # so that they can hit the cache, this assertion would no longer hold.
+            assert cached_tokens1 == cached_tokens2 == cached_tokens0
+
+        asyncio.run(run_test())
+
+
+if __name__ == "__main__":
+    unittest.main()
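
Usage sketch (illustrative, not part of the patch): once a server is launched
with --enable-cache-report, the new metric surfaces through the
OpenAI-compatible API as usage.prompt_tokens_details. The base URL, port, and
model name below are placeholder assumptions rather than values taken from the
patch; without the flag, prompt_tokens_details stays None.

    import openai

    # Placeholder endpoint: assumes a locally launched SGLang server started
    # with --enable-cache-report (the port is an assumption).
    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

    response = client.chat.completions.create(
        model="default",  # placeholder model name
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        temperature=0,
        max_tokens=32,
    )

    # With --enable-cache-report, prompt_tokens_details reports how many prompt
    # tokens were already present in the KV cache (e.g. via radix-cache reuse).
    details = response.usage.prompt_tokens_details
    if details is not None:
        print("cached_tokens:", details.cached_tokens)

Sending the same prompt twice should show cached_tokens increase on the second
call, mirroring the assertions in test_cache_report.py.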