Fix no-cache mode (#136)
This commit is contained in:
@@ -215,8 +215,9 @@ class Batch:
|
|||||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
|
if not self.tree_cache.disable:
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
|
||||||
|
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||||
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
print("Prefill out of memory. This should nerver happen.")
|
print("Prefill out of memory. This should nerver happen.")
|
||||||
@@ -277,11 +278,11 @@ class Batch:
|
|||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self):
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
avai_size = self.token_to_kv_pool.available_size()
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
if avai_size >= bs:
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
if not self.tree_cache.disable:
|
||||||
|
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user