Fix prefill OOM error in the case of large page size (#5081)
This commit is contained in:
@@ -455,7 +455,10 @@ class PrefillAdder:
|
|||||||
total_tokens = req.extend_input_len + min(
|
total_tokens = req.extend_input_len + min(
|
||||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
||||||
)
|
)
|
||||||
input_tokens = req.extend_input_len
|
input_tokens = (
|
||||||
|
-(-req.extend_input_len // self.tree_cache.page_size)
|
||||||
|
* self.tree_cache.page_size
|
||||||
|
)
|
||||||
prefix_len = len(req.prefix_indices)
|
prefix_len = len(req.prefix_indices)
|
||||||
|
|
||||||
if total_tokens >= self.rem_total_tokens:
|
if total_tokens >= self.rem_total_tokens:
|
||||||
@@ -477,7 +480,10 @@ class PrefillAdder:
|
|||||||
req.last_node_global, req.prefix_indices
|
req.last_node_global, req.prefix_indices
|
||||||
)
|
)
|
||||||
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
||||||
input_tokens = req.extend_input_len
|
input_tokens = (
|
||||||
|
-(-req.extend_input_len // self.tree_cache.page_size)
|
||||||
|
* self.tree_cache.page_size
|
||||||
|
)
|
||||||
prefix_len = len(req.prefix_indices)
|
prefix_len = len(req.prefix_indices)
|
||||||
|
|
||||||
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||||
|
|||||||
@@ -502,6 +502,7 @@ class Scheduler(
|
|||||||
self.tree_cache = ChunkCache(
|
self.tree_cache = ChunkCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
page_size=self.page_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
|
|||||||
@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
|
|||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
|
page_size: int,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
|
self.page_size = page_size
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user