Add a test case for cached_tokens (#3145)
This commit is contained in:
10
README.md
10
README.md
@@ -19,16 +19,16 @@
|
|||||||
| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
|
| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
|
||||||
|
|
||||||
## News
|
## News
|
||||||
- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
|
- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html))
|
||||||
- [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
|
- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
|
||||||
- [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
|
- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
|
||||||
- [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
|
- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>More</summary>
|
<summary>More</summary>
|
||||||
|
|
||||||
|
- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
|
||||||
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
|
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
|
||||||
- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)).
|
|
||||||
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
|
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
|
||||||
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
|
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
|
||||||
|
|
||||||
|
|||||||
@@ -331,6 +331,7 @@ class Req:
|
|||||||
|
|
||||||
# The number of cached tokens, that were already cached in the KV cache
|
# The number of cached tokens, that were already cached in the KV cache
|
||||||
self.cached_tokens = 0
|
self.cached_tokens = 0
|
||||||
|
self.already_computed = 0
|
||||||
|
|
||||||
def extend_image_inputs(self, image_inputs):
|
def extend_image_inputs(self, image_inputs):
|
||||||
if self.image_inputs is None:
|
if self.image_inputs is None:
|
||||||
@@ -750,13 +751,6 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
already_computed = (
|
|
||||||
req.extend_logprob_start_len + 1 + req.cached_tokens
|
|
||||||
if req.extend_logprob_start_len > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
req.cached_tokens += len(req.prefix_indices) - already_computed
|
|
||||||
|
|
||||||
req.req_pool_idx = req_pool_indices[i]
|
req.req_pool_idx = req_pool_indices[i]
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||||
seq_lens.append(seq_len)
|
seq_lens.append(seq_len)
|
||||||
@@ -772,15 +766,20 @@ class ScheduleBatch:
|
|||||||
# If req.input_embeds is already a list, append its content directly
|
# If req.input_embeds is already a list, append its content directly
|
||||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||||
|
|
||||||
# Compute the relative logprob_start_len in an extend batch
|
if req.return_logprob:
|
||||||
if req.logprob_start_len >= pre_len:
|
# Compute the relative logprob_start_len in an extend batch
|
||||||
extend_logprob_start_len = min(
|
if req.logprob_start_len >= pre_len:
|
||||||
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
extend_logprob_start_len = min(
|
||||||
)
|
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
||||||
else:
|
)
|
||||||
extend_logprob_start_len = req.extend_input_len - 1
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
|
||||||
|
)
|
||||||
|
req.extend_logprob_start_len = extend_logprob_start_len
|
||||||
|
|
||||||
req.extend_logprob_start_len = extend_logprob_start_len
|
req.cached_tokens += pre_len - req.already_computed
|
||||||
|
req.already_computed = seq_len
|
||||||
req.is_retracted = False
|
req.is_retracted = False
|
||||||
pre_lens.append(pre_len)
|
pre_lens.append(pre_len)
|
||||||
|
|
||||||
|
|||||||
@@ -660,24 +660,23 @@ class Scheduler:
|
|||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Copy more attributes
|
|
||||||
req.logprob_start_len = recv_req.logprob_start_len
|
|
||||||
|
|
||||||
if req.logprob_start_len == -1:
|
|
||||||
# By default, only return the logprobs for output tokens
|
|
||||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
|
||||||
|
|
||||||
# Validate prompts length
|
# Validate prompts length
|
||||||
error_msg = validate_input_length(
|
error_msg = validate_input_length(
|
||||||
req,
|
req,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.server_args.allow_auto_truncate,
|
self.server_args.allow_auto_truncate,
|
||||||
)
|
)
|
||||||
|
|
||||||
if error_msg:
|
if error_msg:
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Copy more attributes
|
||||||
|
if recv_req.logprob_start_len == -1:
|
||||||
|
# By default, only return the logprobs for output tokens
|
||||||
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||||
|
else:
|
||||||
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
|
|
||||||
req.sampling_params.max_new_tokens = min(
|
req.sampling_params.max_new_tokens = min(
|
||||||
(
|
(
|
||||||
req.sampling_params.max_new_tokens
|
req.sampling_params.max_new_tokens
|
||||||
@@ -725,12 +724,17 @@ class Scheduler:
|
|||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
# Validate prompts length
|
# Validate prompts length
|
||||||
validate_input_length(
|
error_msg = validate_input_length(
|
||||||
req,
|
req,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.server_args.allow_auto_truncate,
|
self.server_args.allow_auto_truncate,
|
||||||
)
|
)
|
||||||
|
if error_msg:
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Copy more attributes
|
||||||
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
|
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
|
||||||
@@ -1044,26 +1048,23 @@ class Scheduler:
|
|||||||
self.forward_ct += 1
|
self.forward_ct += 1
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
|
if self.spec_algorithm.is_none():
|
||||||
if self.spec_algorithm.is_none():
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
logits_output, next_token_ids = (
|
model_worker_batch
|
||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
logits_output,
|
|
||||||
next_token_ids,
|
|
||||||
model_worker_batch,
|
|
||||||
num_accepted_tokens,
|
|
||||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
|
||||||
self.spec_num_total_accepted_tokens += (
|
|
||||||
num_accepted_tokens + batch.batch_size()
|
|
||||||
)
|
|
||||||
self.spec_num_total_forward_ct += batch.batch_size()
|
|
||||||
self.num_generated_tokens += num_accepted_tokens
|
|
||||||
else:
|
else:
|
||||||
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
|
(
|
||||||
|
logits_output,
|
||||||
|
next_token_ids,
|
||||||
|
model_worker_batch,
|
||||||
|
num_accepted_tokens,
|
||||||
|
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||||
|
self.spec_num_total_accepted_tokens += (
|
||||||
|
num_accepted_tokens + batch.batch_size()
|
||||||
|
)
|
||||||
|
self.spec_num_total_forward_ct += batch.batch_size()
|
||||||
|
self.num_generated_tokens += num_accepted_tokens
|
||||||
batch.output_ids = next_token_ids
|
batch.output_ids = next_token_ids
|
||||||
|
|
||||||
ret = GenerationBatchResult(
|
ret = GenerationBatchResult(
|
||||||
@@ -1072,7 +1073,6 @@ class Scheduler:
|
|||||||
bid=model_worker_batch.bid,
|
bid=model_worker_batch.bid,
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
assert batch.extend_num_tokens != 0
|
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||||
ret = EmbeddingBatchResult(
|
ret = EmbeddingBatchResult(
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ suites = {
|
|||||||
"test_eagle_infer.py",
|
"test_eagle_infer.py",
|
||||||
"test_embedding_openai_server.py",
|
"test_embedding_openai_server.py",
|
||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
"test_get_weights_by_name.py",
|
|
||||||
"test_gguf.py",
|
"test_gguf.py",
|
||||||
"test_input_embeddings.py",
|
"test_input_embeddings.py",
|
||||||
"test_json_constrained.py",
|
"test_json_constrained.py",
|
||||||
|
|||||||
@@ -236,12 +236,5 @@ class TestEBNFConstrained(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestJumpForward(TestEBNFConstrained):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, disable_overlap=True)
|
|
||||||
cls.check_jump_forward = True
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -317,12 +318,6 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
"""Test custom logit processor with a single request."""
|
"""Test custom logit processor with a single request."""
|
||||||
self.run_custom_logit_processor(target_token_id=5)
|
self.run_custom_logit_processor(target_token_id=5)
|
||||||
|
|
||||||
def test_custom_logit_processor_batch(self):
|
|
||||||
"""Test custom logit processor with a batch of requests."""
|
|
||||||
target_token_ids = list(range(32))
|
|
||||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
|
||||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
|
||||||
|
|
||||||
def test_custom_logit_processor_batch_mixed(self):
|
def test_custom_logit_processor_batch_mixed(self):
|
||||||
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||||
target_token_ids = list(range(32)) + [None] * 16
|
target_token_ids = list(range(32)) + [None] * 16
|
||||||
@@ -330,6 +325,31 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||||
|
|
||||||
|
def test_cache_tokens(self):
|
||||||
|
for _ in range(2):
|
||||||
|
time.sleep(1)
|
||||||
|
response = requests.post(self.base_url + "/flush_cache")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def send_and_check_cached_tokens(input_ids):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": list(input_ids),
|
||||||
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
return response_json["meta_info"]["cached_tokens"]
|
||||||
|
|
||||||
|
self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
|
||||||
|
self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
|
||||||
|
self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
|
||||||
|
self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
|
||||||
|
self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)
|
||||||
|
|
||||||
def test_get_server_info(self):
|
def test_get_server_info(self):
|
||||||
response = requests.get(self.base_url + "/get_server_info")
|
response = requests.get(self.base_url + "/get_server_info")
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|||||||
Reference in New Issue
Block a user