Further fix memory pool leak error (#9298)
This commit is contained in:
@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
device: str,
|
device: str,
|
||||||
kvcache: KVCache,
|
kvcache: KVCache,
|
||||||
need_sort: bool,
|
need_sort: bool,
|
||||||
max_num_extend_tokens: int,
|
|
||||||
):
|
):
|
||||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||||
self.num_pages = size // page_size
|
self.num_pages = size // page_size
|
||||||
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
|
|
||||||
max_num_extend_tokens
|
|
||||||
)
|
|
||||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||||
|
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def alloc(self, need_size: int):
|
def alloc(self, need_size: int):
|
||||||
@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.seen_max_num_extend_tokens_next_power_of_2 = max(
|
||||||
|
self.seen_max_num_extend_tokens_next_power_of_2,
|
||||||
|
next_power_of_2(extend_num_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
bs = len(prefix_lens)
|
bs = len(prefix_lens)
|
||||||
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
|
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
|
||||||
self.free_pages
|
self.free_pages
|
||||||
):
|
):
|
||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
assert self.max_num_extend_tokens_next_power_of_2 >= extend_num_tokens, (
|
|
||||||
f"{self.max_num_extend_tokens_next_power_of_2=} >= {extend_num_tokens=} does not hold. "
|
|
||||||
f"If this happens in PD, consider letting chunked_prefill_size in D be as large as in P"
|
|
||||||
)
|
|
||||||
|
|
||||||
out_indices = torch.empty(
|
out_indices = torch.empty(
|
||||||
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
||||||
)
|
)
|
||||||
@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
self.ret_values,
|
self.ret_values,
|
||||||
next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
self.page_size,
|
self.page_size,
|
||||||
self.max_num_extend_tokens_next_power_of_2,
|
self.seen_max_num_extend_tokens_next_power_of_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
|
|||||||
@@ -1353,11 +1353,6 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Initialize token_to_kv_pool_allocator
|
# Initialize token_to_kv_pool_allocator
|
||||||
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
||||||
max_num_extend_tokens = (
|
|
||||||
self.server_args.chunked_prefill_size
|
|
||||||
if self.server_args.chunked_prefill_size > 0
|
|
||||||
else self.server_args.max_prefill_tokens
|
|
||||||
)
|
|
||||||
if self.token_to_kv_pool_allocator is None:
|
if self.token_to_kv_pool_allocator is None:
|
||||||
if self.server_args.attention_backend == "ascend":
|
if self.server_args.attention_backend == "ascend":
|
||||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||||
@@ -1396,7 +1391,6 @@ class ModelRunner:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
kvcache=self.token_to_kv_pool,
|
kvcache=self.token_to_kv_pool,
|
||||||
need_sort=need_sort,
|
need_sort=need_sort,
|
||||||
max_num_extend_tokens=max_num_extend_tokens,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.is_draft_worker
|
assert self.is_draft_worker
|
||||||
|
|||||||
Reference in New Issue
Block a user