From d1a08632519e7c950998d44475172c4d53e9b0c3 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 26 Jan 2025 01:39:28 -0800
Subject: [PATCH] Add a test case for cached_tokens (#3145)

---
 README.md                                    | 10 ++--
 python/sglang/srt/managers/schedule_batch.py | 29 +++++-----
 python/sglang/srt/managers/scheduler.py      | 58 ++++++++++----------
 test/srt/run_suite.py                        |  1 -
 test/srt/test_ebnf_constrained.py            |  7 ---
 test/srt/test_srt_endpoint.py                | 32 +++++++++--
 6 files changed, 74 insertions(+), 63 deletions(-)

diff --git a/README.md b/README.md
index 1165826c5..63b2124bf 100644
--- a/README.md
+++ b/README.md
@@ -19,16 +19,16 @@
 | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
 
 ## News
-- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
-- [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
-- [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
-- [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
+- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html))
+- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
+- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
+- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
 
 <details>
 <summary>More</summary>
 
+- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
 - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
-- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)).
 - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
 - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 6c44b17ff..2a342c5df 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -331,6 +331,7 @@ class Req:
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -750,13 +751,6 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -772,15 +766,20 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                extend_logprob_start_len = req.extend_input_len - 1
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.extend_logprob_start_len = extend_logprob_start_len
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 85bd1c2a4..9cfa14c30 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -660,24 +660,23 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
-
         if error_msg:
             self.waiting_queue.append(req)
             return
 
+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
@@ -725,12 +724,17 @@ class Scheduler:
         req.tokenizer = self.tokenizer
 
         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)
 
     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
@@ -1044,26 +1048,23 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
 
             batch.output_ids = next_token_ids
             ret = GenerationBatchResult(
@@ -1072,7 +1073,6 @@
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 69a5470be..90c2c15cb 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -18,7 +18,6 @@ suites = {
         "test_eagle_infer.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
-        "test_get_weights_by_name.py",
         "test_gguf.py",
         "test_input_embeddings.py",
         "test_json_constrained.py",
diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py
index 97b6f7561..5e852bec6 100644
--- a/test/srt/test_ebnf_constrained.py
+++ b/test/srt/test_ebnf_constrained.py
@@ -236,12 +236,5 @@ class TestEBNFConstrained(unittest.TestCase):
         )
 
 
-class TestJumpForward(TestEBNFConstrained):
-    @classmethod
-    def setUpClass(cls):
-        setup_class(cls, disable_overlap=True)
-        cls.check_jump_forward = True
-
-
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 7c57c13e2..b4e71183d 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -5,6 +5,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 
 import json
 import random
+import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
@@ -317,12 +318,6 @@ class TestSRTEndpoint(unittest.TestCase):
         """Test custom logit processor with a single request."""
         self.run_custom_logit_processor(target_token_id=5)
 
-    def test_custom_logit_processor_batch(self):
-        """Test custom logit processor with a batch of requests."""
-        target_token_ids = list(range(32))
-        with ThreadPoolExecutor(len(target_token_ids)) as executor:
-            list(executor.map(self.run_custom_logit_processor, target_token_ids))
-
     def test_custom_logit_processor_batch_mixed(self):
         """Test a batch of requests mixed of requests with and without custom logit processor."""
         target_token_ids = list(range(32)) + [None] * 16
@@ -330,6 +325,31 @@ class TestSRTEndpoint(unittest.TestCase):
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
+    def test_cache_tokens(self):
+        for _ in range(2):
+            time.sleep(1)
+            response = requests.post(self.base_url + "/flush_cache")
+            assert response.status_code == 200
+
+        def send_and_check_cached_tokens(input_ids):
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    "input_ids": list(input_ids),
+                    "sampling_params": {
+                        "max_new_tokens": 1,
+                    },
+                },
+            )
+            response_json = response.json()
+            return response_json["meta_info"]["cached_tokens"]
+
+        self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()