Further fix memory pool leak error (#9298)
This commit is contained in:
@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
device: str,
|
||||
kvcache: KVCache,
|
||||
need_sort: bool,
|
||||
max_num_extend_tokens: int,
|
||||
):
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.num_pages = size // page_size
|
||||
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
|
||||
max_num_extend_tokens
|
||||
)
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
||||
self.clear()
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||
)
|
||||
|
||||
self.seen_max_num_extend_tokens_next_power_of_2 = max(
|
||||
self.seen_max_num_extend_tokens_next_power_of_2,
|
||||
next_power_of_2(extend_num_tokens),
|
||||
)
|
||||
|
||||
bs = len(prefix_lens)
|
||||
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
|
||||
self.free_pages
|
||||
):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
assert self.max_num_extend_tokens_next_power_of_2 >= extend_num_tokens, (
|
||||
f"{self.max_num_extend_tokens_next_power_of_2=} >= {extend_num_tokens=} does not hold. "
|
||||
f"If this happens in PD, consider letting chunked_prefill_size in D be as large as in P"
|
||||
)
|
||||
|
||||
out_indices = torch.empty(
|
||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
self.max_num_extend_tokens_next_power_of_2,
|
||||
self.seen_max_num_extend_tokens_next_power_of_2,
|
||||
)
|
||||
|
||||
if self.debug_mode:
|
||||
|
||||
@@ -1353,11 +1353,6 @@ class ModelRunner:
|
||||
|
||||
# Initialize token_to_kv_pool_allocator
|
||||
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
||||
max_num_extend_tokens = (
|
||||
self.server_args.chunked_prefill_size
|
||||
if self.server_args.chunked_prefill_size > 0
|
||||
else self.server_args.max_prefill_tokens
|
||||
)
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
if self.server_args.attention_backend == "ascend":
|
||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||
@@ -1396,7 +1391,6 @@ class ModelRunner:
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
max_num_extend_tokens=max_num_extend_tokens,
|
||||
)
|
||||
else:
|
||||
assert self.is_draft_worker
|
||||
|
||||
Reference in New Issue
Block a user