Fix memory leak for chunked prefill 2 (#1858)
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
6
.github/workflows/pr-test.yml
vendored
6
.github/workflows/pr-test.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
|
python3 run_suite.py --suite minimal --range-begin 0 --range-end 4
|
||||||
|
|
||||||
unit-test-backend-part-2:
|
unit-test-backend-part-2:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -67,7 +67,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
|
python3 run_suite.py --suite minimal --range-begin 4 --range-end 14
|
||||||
|
|
||||||
unit-test-backend-part-3:
|
unit-test-backend-part-3:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -84,7 +84,7 @@ jobs:
|
|||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 17 --range-end 20
|
python3 run_suite.py --suite minimal --range-begin 14 --range-end 20
|
||||||
|
|
||||||
unit-test-backend-part-4:
|
unit-test-backend-part-4:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# Guide on Hyperparameter Tuning
|
# Guide on Hyperparameter Tuning
|
||||||
|
|
||||||
## Achieving Peak Throughput
|
## Achieving Peak Throughput
|
||||||
|
|
||||||
Achieving a large batch size is the most important thing for attaining high throughput.
|
Achieving a large batch size is the most important thing for attaining high throughput.
|
||||||
|
|
||||||
When the server is running at full load, look for the following in the log:
|
When the server is running at full load, look for the following in the log:
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ class Req:
|
|||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
self.is_inflight_req = 0
|
self.is_being_chunked = 0
|
||||||
|
|
||||||
# Logprobs (arguments)
|
# Logprobs (arguments)
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
@@ -888,7 +888,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
current_inflight_req: Optional[Req] = None,
|
being_chunked_req: Optional[Req] = None,
|
||||||
keep_indices: Optional[List[int]] = None,
|
keep_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if keep_indices is None:
|
if keep_indices is None:
|
||||||
@@ -896,7 +896,7 @@ class ScheduleBatch:
|
|||||||
i
|
i
|
||||||
for i in range(len(self.reqs))
|
for i in range(len(self.reqs))
|
||||||
if not self.reqs[i].finished()
|
if not self.reqs[i].finished()
|
||||||
and self.reqs[i] is not current_inflight_req
|
and self.reqs[i] is not being_chunked_req
|
||||||
]
|
]
|
||||||
|
|
||||||
if keep_indices is None or len(keep_indices) == 0:
|
if keep_indices is None or len(keep_indices) == 0:
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class Scheduler:
|
|||||||
|
|
||||||
# Init chunked prefill
|
# Init chunked prefill
|
||||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||||
self.current_inflight_req = None
|
self.being_chunked_req = None
|
||||||
self.is_mixed_chunk = (
|
self.is_mixed_chunk = (
|
||||||
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
||||||
)
|
)
|
||||||
@@ -551,13 +551,13 @@ class Scheduler:
|
|||||||
and not self.last_batch.forward_mode.is_decode()
|
and not self.last_batch.forward_mode.is_decode()
|
||||||
and not self.last_batch.is_empty()
|
and not self.last_batch.is_empty()
|
||||||
):
|
):
|
||||||
if self.current_inflight_req:
|
if self.being_chunked_req:
|
||||||
self.last_batch.filter_batch(
|
self.last_batch.filter_batch(
|
||||||
current_inflight_req=self.current_inflight_req
|
being_chunked_req=self.being_chunked_req
|
||||||
)
|
)
|
||||||
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
|
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
||||||
# Inflight request keeps its rid but will get a new req_pool_idx.
|
# Inflight request keeps its rid but will get a new req_pool_idx.
|
||||||
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
|
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
if not self.last_batch.is_empty():
|
if not self.last_batch.is_empty():
|
||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
@@ -588,7 +588,7 @@ class Scheduler:
|
|||||||
# Handle the cases where prefill is not allowed
|
# Handle the cases where prefill is not allowed
|
||||||
if (
|
if (
|
||||||
self.batch_is_full or len(self.waiting_queue) == 0
|
self.batch_is_full or len(self.waiting_queue) == 0
|
||||||
) and self.current_inflight_req is None:
|
) and self.being_chunked_req is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||||
@@ -611,13 +611,11 @@ class Scheduler:
|
|||||||
num_mixed_running,
|
num_mixed_running,
|
||||||
)
|
)
|
||||||
|
|
||||||
has_inflight = self.current_inflight_req is not None
|
has_inflight = self.being_chunked_req is not None
|
||||||
if has_inflight:
|
if has_inflight:
|
||||||
self.current_inflight_req.init_next_round_input(
|
self.being_chunked_req.init_next_round_input()
|
||||||
None if prefix_computed else self.tree_cache
|
self.being_chunked_req = adder.add_inflight_req(
|
||||||
)
|
self.being_chunked_req
|
||||||
self.current_inflight_req = adder.add_inflight_req(
|
|
||||||
self.current_inflight_req
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.lora_paths:
|
if self.lora_paths:
|
||||||
@@ -661,11 +659,11 @@ class Scheduler:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if adder.new_inflight_req is not None:
|
if adder.new_inflight_req is not None:
|
||||||
assert self.current_inflight_req is None
|
assert self.being_chunked_req is None
|
||||||
self.current_inflight_req = adder.new_inflight_req
|
self.being_chunked_req = adder.new_inflight_req
|
||||||
|
|
||||||
if self.current_inflight_req:
|
if self.being_chunked_req:
|
||||||
self.current_inflight_req.is_inflight_req += 1
|
self.being_chunked_req.is_being_chunked += 1
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
@@ -833,8 +831,8 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req.is_inflight_req > 0:
|
if req.is_being_chunked > 0:
|
||||||
req.is_inflight_req -= 1
|
req.is_being_chunked -= 1
|
||||||
else:
|
else:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
@@ -860,8 +858,8 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
req.embedding = embeddings[i]
|
req.embedding = embeddings[i]
|
||||||
if req.is_inflight_req > 0:
|
if req.is_being_chunked > 0:
|
||||||
req.is_inflight_req -= 1
|
req.is_being_chunked -= 1
|
||||||
else:
|
else:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
# dummy output token for embedding models
|
# dummy output token for embedding models
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
# Kill all SGLang processes and free the GPU memory.
|
||||||
Kill all SGLang processes and free the GPU memory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
|
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
|
||||||
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
|
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ suites = {
|
|||||||
"test_openai_server.py",
|
"test_openai_server.py",
|
||||||
"test_overlap_schedule.py",
|
"test_overlap_schedule.py",
|
||||||
"test_pytorch_sampling_backend.py",
|
"test_pytorch_sampling_backend.py",
|
||||||
|
"test_radix_attention.py",
|
||||||
"test_retract_decode.py",
|
"test_retract_decode.py",
|
||||||
"test_server_args.py",
|
"test_server_args.py",
|
||||||
"test_skip_tokenizer_init.py",
|
"test_skip_tokenizer_init.py",
|
||||||
|
|||||||
112
test/srt/test_radix_attention.py
Normal file
112
test/srt/test_radix_attention.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
kill_child_process,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_radix_tree(num_nodes=400, chunk_len=256):
|
||||||
|
num0 = num_nodes // 2
|
||||||
|
num1 = num_nodes - num0
|
||||||
|
nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
|
||||||
|
for _ in range(num0):
|
||||||
|
parent = random.choice(nodes)
|
||||||
|
unique_len = random.randint(0, chunk_len)
|
||||||
|
decode_len = random.randint(0, chunk_len)
|
||||||
|
token_id = random.randint(0, 32000)
|
||||||
|
child = {
|
||||||
|
"input_ids": parent["input_ids"] + [token_id] * unique_len,
|
||||||
|
"decode_len": decode_len,
|
||||||
|
}
|
||||||
|
nodes.append(child)
|
||||||
|
|
||||||
|
while num1 > 0:
|
||||||
|
num_branch = random.randint(1, min(num1, 10))
|
||||||
|
parent = random.choice(nodes)
|
||||||
|
for _ in range(num_branch):
|
||||||
|
unique_len = random.randint(0, chunk_len)
|
||||||
|
decode_len = random.randint(0, chunk_len)
|
||||||
|
token_id = random.randint(0, 32000)
|
||||||
|
child = {
|
||||||
|
"input_ids": parent["input_ids"] + [token_id] * unique_len,
|
||||||
|
"decode_len": decode_len,
|
||||||
|
}
|
||||||
|
nodes.append(child)
|
||||||
|
|
||||||
|
num1 -= num_branch
|
||||||
|
|
||||||
|
random.shuffle(nodes)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(base_url, nodes):
|
||||||
|
data = {
|
||||||
|
"input_ids": [node["input_ids"] for node in nodes],
|
||||||
|
"sampling_params": [
|
||||||
|
{"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
res = requests.post(base_url + "/generate", json=data)
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestRadixCacheFCFS(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"20000",
|
||||||
|
"--schedule-policy",
|
||||||
|
"fcfs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
|
def test_radix_attention(self):
|
||||||
|
nodes = gen_radix_tree()
|
||||||
|
run_test(self.base_url, nodes)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRadixCacheLPM(TestRadixCacheFCFS):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"20000",
|
||||||
|
"--schedule-policy",
|
||||||
|
"lpm",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
os.environ["SGLANG_TEST_RETRACT"] = "true"
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user