Fix memory leak when doing chunked prefill (#1787)
This commit is contained in:
@@ -15,7 +15,7 @@ class GlobalConfig:
|
|||||||
|
|
||||||
# Runtime constants: New generation token ratio estimation
|
# Runtime constants: New generation token ratio estimation
|
||||||
self.init_new_token_ratio = 0.7
|
self.init_new_token_ratio = 0.7
|
||||||
self.base_min_new_token_ratio = 0.1
|
self.min_new_token_ratio = 0.1
|
||||||
self.new_token_ratio_decay = 0.001
|
self.new_token_ratio_decay = 0.001
|
||||||
|
|
||||||
# Runtime constants: others
|
# Runtime constants: others
|
||||||
@@ -32,5 +32,15 @@ class GlobalConfig:
|
|||||||
self.enable_precache_with_tracing = True
|
self.enable_precache_with_tracing = True
|
||||||
self.enable_parallel_encoding = True
|
self.enable_parallel_encoding = True
|
||||||
|
|
||||||
|
def adjust_new_token_ratio(self, schedule_conservativeness=1):
|
||||||
|
assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
|
||||||
|
min_new_token_ratio = min(
|
||||||
|
self.min_new_token_ratio * schedule_conservativeness,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
|
||||||
|
|
||||||
|
return min_new_token_ratio, init_new_token_ratio
|
||||||
|
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class Req:
|
|||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
self.is_inflight_req = 0
|
self.is_being_chunked = False
|
||||||
|
|
||||||
# Logprobs (arguments)
|
# Logprobs (arguments)
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
@@ -906,15 +906,14 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
current_inflight_req: Optional[Req] = None,
|
being_chunked_req: Optional[Req] = None,
|
||||||
keep_indices: Optional[List[int]] = None,
|
keep_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if keep_indices is None:
|
if keep_indices is None:
|
||||||
keep_indices = [
|
keep_indices = [
|
||||||
i
|
i
|
||||||
for i in range(len(self.reqs))
|
for i in range(len(self.reqs))
|
||||||
if not self.reqs[i].finished()
|
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
|
||||||
and self.reqs[i] is not current_inflight_req
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if keep_indices is None or len(keep_indices) == 0:
|
if keep_indices is None or len(keep_indices) == 0:
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ class PrefillAdder:
|
|||||||
|
|
||||||
self.req_states = None
|
self.req_states = None
|
||||||
self.can_run_list = []
|
self.can_run_list = []
|
||||||
self.new_inflight_req = None
|
self.new_chunked_req = None
|
||||||
self.log_hit_tokens = 0
|
self.log_hit_tokens = 0
|
||||||
self.log_input_tokens = 0
|
self.log_input_tokens = 0
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ class PrefillAdder:
|
|||||||
self.log_hit_tokens += prefix_len
|
self.log_hit_tokens += prefix_len
|
||||||
self.log_input_tokens += extend_input_len
|
self.log_input_tokens += extend_input_len
|
||||||
|
|
||||||
def add_inflight_req(self, req: Req):
|
def add_being_chunked_req(self, req: Req):
|
||||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
@@ -192,8 +192,13 @@ class PrefillAdder:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return if chunked prefill not finished
|
if truncated:
|
||||||
return req if truncated else None
|
# Continue to chunk the request
|
||||||
|
assert req.is_being_chunked
|
||||||
|
self.new_chunked_req = req
|
||||||
|
else:
|
||||||
|
# Release the being chunked status
|
||||||
|
req.is_being_chunked = False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _lock_node(self, last_node: TreeNode):
|
def _lock_node(self, last_node: TreeNode):
|
||||||
@@ -262,11 +267,14 @@ class PrefillAdder:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
|
assert self.new_chunked_req is None
|
||||||
|
|
||||||
trunc_len = self.rem_chunk_tokens
|
trunc_len = self.rem_chunk_tokens
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
|
req.is_being_chunked = True
|
||||||
req.fill_ids = req.fill_ids[:trunc_len]
|
req.fill_ids = req.fill_ids[:trunc_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_inflight_req = req
|
self.new_chunked_req = req
|
||||||
self._prefill_one_req(0, trunc_len, 0)
|
self._prefill_one_req(0, trunc_len, 0)
|
||||||
|
|
||||||
return self.budget_state()
|
return self.budget_state()
|
||||||
@@ -305,15 +313,18 @@ class PrefillAdder:
|
|||||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
|
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Chunked prefill
|
|
||||||
trunc_len = self.rem_chunk_tokens
|
trunc_len = self.rem_chunk_tokens
|
||||||
if trunc_len == 0:
|
if trunc_len == 0:
|
||||||
return AddReqResult.OTHER
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
|
# Chunked prefill
|
||||||
|
assert self.new_chunked_req is None
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
|
req.is_being_chunked = True
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_inflight_req = req
|
self.new_chunked_req = req
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||||
|
|
||||||
|
|||||||
@@ -219,35 +219,28 @@ class Scheduler:
|
|||||||
|
|
||||||
# Init chunked prefill
|
# Init chunked prefill
|
||||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||||
self.current_inflight_req = None
|
self.being_chunked_req = None
|
||||||
self.is_mixed_chunk = (
|
self.is_mixed_chunk = (
|
||||||
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
if not server_args.skip_tokenizer_init:
|
self.regex_fsm_cache = FSMCache(
|
||||||
self.regex_fsm_cache = FSMCache(
|
server_args.tokenizer_path,
|
||||||
server_args.tokenizer_path,
|
{
|
||||||
{
|
"tokenizer_mode": server_args.tokenizer_mode,
|
||||||
"tokenizer_mode": server_args.tokenizer_mode,
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
"trust_remote_code": server_args.trust_remote_code,
|
},
|
||||||
},
|
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||||
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
)
|
||||||
)
|
|
||||||
self.jump_forward_cache = JumpForwardCache()
|
self.jump_forward_cache = JumpForwardCache()
|
||||||
|
|
||||||
# Init new token estimation
|
# Init new token estimation
|
||||||
assert (
|
self.min_new_token_ratio, self.init_new_token_ratio = (
|
||||||
server_args.schedule_conservativeness >= 0
|
global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
|
||||||
), "Invalid schedule_conservativeness"
|
|
||||||
self.min_new_token_ratio = min(
|
|
||||||
global_config.base_min_new_token_ratio
|
|
||||||
* server_args.schedule_conservativeness,
|
|
||||||
1.0,
|
|
||||||
)
|
)
|
||||||
self.new_token_ratio = self.min_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
# Init profiler
|
# Init profiler
|
||||||
@@ -294,7 +287,7 @@ class Scheduler:
|
|||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
else:
|
else:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -321,7 +314,7 @@ class Scheduler:
|
|||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
elif batch is None:
|
elif batch is None:
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -499,20 +492,18 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
exit(1) if crash_on_warning else None
|
exit(1) if crash_on_warning else None
|
||||||
|
|
||||||
def get_next_batch_to_run(self):
|
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||||
# Merge the prefill batch into the running batch
|
# Merge the prefill batch into the running batch
|
||||||
if (
|
if (
|
||||||
self.last_batch
|
self.last_batch
|
||||||
and not self.last_batch.forward_mode.is_decode()
|
and not self.last_batch.forward_mode.is_decode()
|
||||||
and not self.last_batch.is_empty()
|
and not self.last_batch.is_empty()
|
||||||
):
|
):
|
||||||
if self.current_inflight_req:
|
if self.being_chunked_req:
|
||||||
self.last_batch.filter_batch(
|
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
|
||||||
current_inflight_req=self.current_inflight_req
|
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
||||||
)
|
# Being chunked request keeps its rid but will get a new req_pool_idx.
|
||||||
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
|
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
||||||
# Inflight request keeps its rid but will get a new req_pool_idx.
|
|
||||||
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
|
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
if not self.last_batch.is_empty():
|
if not self.last_batch.is_empty():
|
||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
@@ -543,7 +534,7 @@ class Scheduler:
|
|||||||
# Handle the cases where prefill is not allowed
|
# Handle the cases where prefill is not allowed
|
||||||
if (
|
if (
|
||||||
self.batch_is_full or len(self.waiting_queue) == 0
|
self.batch_is_full or len(self.waiting_queue) == 0
|
||||||
) and self.current_inflight_req is None:
|
) and self.being_chunked_req is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||||
@@ -566,15 +557,6 @@ class Scheduler:
|
|||||||
num_mixed_running,
|
num_mixed_running,
|
||||||
)
|
)
|
||||||
|
|
||||||
has_inflight = self.current_inflight_req is not None
|
|
||||||
if has_inflight:
|
|
||||||
self.current_inflight_req.init_next_round_input(
|
|
||||||
None if prefix_computed else self.tree_cache
|
|
||||||
)
|
|
||||||
self.current_inflight_req = adder.add_inflight_req(
|
|
||||||
self.current_inflight_req
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.lora_paths:
|
if self.lora_paths:
|
||||||
lora_set = (
|
lora_set = (
|
||||||
set([req.lora_path for req in self.running_batch.reqs])
|
set([req.lora_path for req in self.running_batch.reqs])
|
||||||
@@ -582,6 +564,13 @@ class Scheduler:
|
|||||||
else set([])
|
else set([])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE: if there is request being chunked, we always add it first
|
||||||
|
has_being_chunked = self.being_chunked_req is not None
|
||||||
|
if has_being_chunked:
|
||||||
|
# NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
|
||||||
|
self.being_chunked_req.init_next_round_input()
|
||||||
|
adder.add_being_chunked_req(self.being_chunked_req)
|
||||||
|
|
||||||
# Get requests from the waiting queue to a new prefill batch
|
# Get requests from the waiting queue to a new prefill batch
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
if (
|
if (
|
||||||
@@ -615,12 +604,8 @@ class Scheduler:
|
|||||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
]
|
]
|
||||||
|
|
||||||
if adder.new_inflight_req is not None:
|
# Update new round being chunked request
|
||||||
assert self.current_inflight_req is None
|
self.being_chunked_req = adder.new_chunked_req
|
||||||
self.current_inflight_req = adder.new_inflight_req
|
|
||||||
|
|
||||||
if self.current_inflight_req:
|
|
||||||
self.current_inflight_req.is_inflight_req += 1
|
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
@@ -649,7 +634,7 @@ class Scheduler:
|
|||||||
f"#cached-token: {adder.log_hit_tokens}, "
|
f"#cached-token: {adder.log_hit_tokens}, "
|
||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -660,7 +645,7 @@ class Scheduler:
|
|||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"#running-req: {running_bs}, "
|
f"#running-req: {running_bs}, "
|
||||||
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a new batch
|
# Create a new batch
|
||||||
@@ -709,7 +694,7 @@ class Scheduler:
|
|||||||
self.waiting_queue.extend(retracted_reqs)
|
self.waiting_queue.extend(retracted_reqs)
|
||||||
else:
|
else:
|
||||||
self.new_token_ratio = max(
|
self.new_token_ratio = max(
|
||||||
self.new_token_ratio - self.new_token_ratio_decay,
|
self.new_token_ratio - global_config.new_token_ratio_decay,
|
||||||
self.min_new_token_ratio,
|
self.min_new_token_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -783,10 +768,8 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req.is_inflight_req > 0:
|
if not req.is_being_chunked:
|
||||||
req.is_inflight_req -= 1
|
# Being chunked reqs' prefill is not finished
|
||||||
else:
|
|
||||||
# Inflight reqs' prefill is not finished
|
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_ids[i])
|
req.output_ids.append(next_token_ids[i])
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
@@ -812,10 +795,8 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
req.embedding = embeddings[i]
|
req.embedding = embeddings[i]
|
||||||
if req.is_inflight_req > 0:
|
if not req.is_being_chunked:
|
||||||
req.is_inflight_req -= 1
|
# Being chunked reqs' prefill is not finished
|
||||||
else:
|
|
||||||
# Inflight reqs' prefill is not finished
|
|
||||||
# dummy output token for embedding models
|
# dummy output token for embedding models
|
||||||
req.output_ids.append(0)
|
req.output_ids.append(0)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|||||||
@@ -660,6 +660,7 @@ def run_mmlu_test(
|
|||||||
chunked_prefill_size=32,
|
chunked_prefill_size=32,
|
||||||
):
|
):
|
||||||
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
|
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
|
||||||
|
other_args += ["--mem-fraction-static", "0.85"]
|
||||||
if disable_radix_cache:
|
if disable_radix_cache:
|
||||||
other_args += ["--disable-radix-cache"]
|
other_args += ["--disable-radix-cache"]
|
||||||
if enable_mixed_chunk:
|
if enable_mixed_chunk:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from sglang.test.test_utils import run_unittest_files
|
|||||||
|
|
||||||
suites = {
|
suites = {
|
||||||
"minimal": [
|
"minimal": [
|
||||||
|
"test_radix_attention.py",
|
||||||
"models/test_embedding_models.py",
|
"models/test_embedding_models.py",
|
||||||
"models/test_generation_models.py",
|
"models/test_generation_models.py",
|
||||||
"models/test_lora.py",
|
"models/test_lora.py",
|
||||||
|
|||||||
112
test/srt/test_radix_attention.py
Normal file
112
test/srt/test_radix_attention.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
kill_child_process,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_radix_tree(num_nodes=400, chunk_len=256):
|
||||||
|
num0 = num_nodes // 2
|
||||||
|
num1 = num_nodes - num0
|
||||||
|
nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
|
||||||
|
for _ in range(num0):
|
||||||
|
parent = random.choice(nodes)
|
||||||
|
unique_len = random.randint(0, chunk_len)
|
||||||
|
decode_len = random.randint(0, chunk_len)
|
||||||
|
token_id = random.randint(0, 32000)
|
||||||
|
child = {
|
||||||
|
"input_ids": parent["input_ids"] + [token_id] * unique_len,
|
||||||
|
"decode_len": decode_len,
|
||||||
|
}
|
||||||
|
nodes.append(child)
|
||||||
|
|
||||||
|
while num1 > 0:
|
||||||
|
num_branch = random.randint(1, min(num1, 10))
|
||||||
|
parent = random.choice(nodes)
|
||||||
|
for _ in range(num_branch):
|
||||||
|
unique_len = random.randint(0, chunk_len)
|
||||||
|
decode_len = random.randint(0, chunk_len)
|
||||||
|
token_id = random.randint(0, 32000)
|
||||||
|
child = {
|
||||||
|
"input_ids": parent["input_ids"] + [token_id] * unique_len,
|
||||||
|
"decode_len": decode_len,
|
||||||
|
}
|
||||||
|
nodes.append(child)
|
||||||
|
|
||||||
|
num1 -= num_branch
|
||||||
|
|
||||||
|
random.shuffle(nodes)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(base_url, nodes):
|
||||||
|
data = {
|
||||||
|
"input_ids": [node["input_ids"] for node in nodes],
|
||||||
|
"sampling_params": [
|
||||||
|
{"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
res = requests.post(base_url + "/generate", json=data)
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestRadixCacheFCFS(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"20000",
|
||||||
|
"--schedule-policy",
|
||||||
|
"fcfs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid)
|
||||||
|
|
||||||
|
def test_radix_attention(self):
|
||||||
|
nodes = gen_radix_tree()
|
||||||
|
run_test(self.base_url, nodes)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRadixCacheLPM(TestRadixCacheFCFS):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
|
"--max-total-tokens",
|
||||||
|
"20000",
|
||||||
|
"--schedule-policy",
|
||||||
|
"lpm",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
os.environ["SGLANG_TEST_RETRACT"] = "true"
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user