Make req_pool_indices on CPU (#960)
This commit is contained in:
@@ -19,7 +19,6 @@ class GlobalConfig:
|
|||||||
self.init_new_token_ratio = 0.7
|
self.init_new_token_ratio = 0.7
|
||||||
self.base_min_new_token_ratio = 0.1
|
self.base_min_new_token_ratio = 0.1
|
||||||
self.new_token_ratio_decay = 0.001
|
self.new_token_ratio_decay = 0.001
|
||||||
self.new_token_ratio_recovery = 0.05
|
|
||||||
|
|
||||||
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
|
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
|
||||||
# This can improve the speed for large batch sizes during prefill.
|
# This can improve the speed for large batch sizes during prefill.
|
||||||
|
|||||||
@@ -100,6 +100,9 @@ class Req:
|
|||||||
self.output_ids = [] # Each decode stage's output ids
|
self.output_ids = [] # Each decode stage's output ids
|
||||||
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
||||||
|
|
||||||
|
# Memory info
|
||||||
|
self.req_pool_idx = None
|
||||||
|
|
||||||
# For incremental decoding
|
# For incremental decoding
|
||||||
# ----- | --------- read_ids -------|
|
# ----- | --------- read_ids -------|
|
||||||
# ----- | surr_ids |
|
# ----- | surr_ids |
|
||||||
@@ -321,6 +324,9 @@ class ScheduleBatch:
|
|||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def batch_size(self):
|
||||||
|
return len(self.reqs) if self.reqs is not None else 0
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.reqs) == 0
|
return len(self.reqs) == 0
|
||||||
|
|
||||||
@@ -328,52 +334,22 @@ class ScheduleBatch:
|
|||||||
# Return whether batch has at least 1 streaming request
|
# Return whether batch has at least 1 streaming request
|
||||||
return any(r.stream for r in self.reqs)
|
return any(r.stream for r in self.reqs)
|
||||||
|
|
||||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
def alloc_req_slots(self, num_reqs):
|
||||||
device = "cuda"
|
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||||
bs = len(self.reqs)
|
|
||||||
reqs = self.reqs
|
|
||||||
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
|
||||||
prefix_indices = [r.prefix_indices for r in reqs]
|
|
||||||
|
|
||||||
# Handle prefix
|
|
||||||
flatten_input_ids = []
|
|
||||||
extend_lens = []
|
|
||||||
prefix_lens = []
|
|
||||||
seq_lens = []
|
|
||||||
|
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(bs)
|
|
||||||
|
|
||||||
if req_pool_indices is None:
|
if req_pool_indices is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Out of memory. "
|
"Out of memory. "
|
||||||
"Please set a smaller number for `--max-running-requests`."
|
"Please set a smaller number for `--max-running-requests`."
|
||||||
)
|
)
|
||||||
|
return req_pool_indices
|
||||||
|
|
||||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
def alloc_token_slots(self, num_tokens: int):
|
||||||
for i in range(bs):
|
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
||||||
flatten_input_ids.extend(input_ids[i])
|
|
||||||
extend_lens.append(len(input_ids[i]))
|
|
||||||
|
|
||||||
if len(prefix_indices[i]) == 0:
|
|
||||||
prefix_lens.append(0)
|
|
||||||
else:
|
|
||||||
prefix_lens.append(len(prefix_indices[i]))
|
|
||||||
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
|
|
||||||
: len(prefix_indices[i])
|
|
||||||
] = prefix_indices[i]
|
|
||||||
|
|
||||||
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
|
|
||||||
|
|
||||||
position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
# Allocate memory
|
|
||||||
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
|
|
||||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
if self.tree_cache is not None:
|
if self.tree_cache is not None:
|
||||||
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
|
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
||||||
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
logger.error("Prefill out of memory. Try to lower your batch size.")
|
logger.error("Prefill out of memory. Try to lower your batch size.")
|
||||||
@@ -381,40 +357,11 @@ class ScheduleBatch:
|
|||||||
self.tree_cache.pretty_print()
|
self.tree_cache.pretty_print()
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
pt = 0
|
return out_cache_loc
|
||||||
for i in range(bs):
|
|
||||||
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
|
|
||||||
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
|
|
||||||
] = out_cache_loc[pt : pt + extend_lens[i]]
|
|
||||||
pt += extend_lens[i]
|
|
||||||
|
|
||||||
# Handle logit bias but only allocate when needed
|
|
||||||
logit_bias = None
|
|
||||||
for i in range(bs):
|
|
||||||
if reqs[i].sampling_params.dtype == "int":
|
|
||||||
if logit_bias is None:
|
|
||||||
logit_bias = torch.zeros(
|
|
||||||
(bs, vocab_size), dtype=torch.float32, device=device
|
|
||||||
)
|
|
||||||
logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
|
||||||
|
|
||||||
# Set fields
|
|
||||||
self.input_ids = torch.tensor(
|
|
||||||
flatten_input_ids, dtype=torch.int32, device=device
|
|
||||||
)
|
|
||||||
self.pixel_values = [r.pixel_values for r in reqs]
|
|
||||||
self.image_sizes = [r.image_size for r in reqs]
|
|
||||||
self.image_offsets = [
|
|
||||||
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
|
|
||||||
]
|
|
||||||
self.req_pool_indices = req_pool_indices
|
|
||||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
|
|
||||||
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
|
||||||
self.position_ids_offsets = position_ids_offsets
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
|
||||||
self.out_cache_loc = out_cache_loc
|
|
||||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
|
||||||
|
|
||||||
|
def batch_sampling_params(self, vocab_size, int_token_logit_bias):
|
||||||
|
device = "cuda"
|
||||||
|
bs, reqs = self.batch_size(), self.reqs
|
||||||
self.temperatures = torch.tensor(
|
self.temperatures = torch.tensor(
|
||||||
[r.sampling_params.temperature for r in reqs],
|
[r.sampling_params.temperature for r in reqs],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
@@ -436,10 +383,78 @@ class ScheduleBatch:
|
|||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
self.logit_bias = logit_bias
|
|
||||||
|
# Handle logit bias but only allocate when needed
|
||||||
|
self.logit_bias = None
|
||||||
|
for i in range(bs):
|
||||||
|
if reqs[i].sampling_params.dtype == "int":
|
||||||
|
if self.logit_bias is None:
|
||||||
|
self.logit_bias = torch.zeros(
|
||||||
|
(bs, vocab_size), dtype=torch.float32, device=device
|
||||||
|
)
|
||||||
|
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||||
|
|
||||||
|
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||||
|
device = "cuda"
|
||||||
|
bs = self.batch_size()
|
||||||
|
reqs = self.reqs
|
||||||
|
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
|
prefix_indices = [r.prefix_indices for r in reqs]
|
||||||
|
|
||||||
|
# Handle prefix
|
||||||
|
extend_lens = []
|
||||||
|
prefix_lens = []
|
||||||
|
seq_lens = []
|
||||||
|
|
||||||
|
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
||||||
|
|
||||||
|
for i, req in enumerate(reqs):
|
||||||
|
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||||
|
extend_lens.append(len(input_ids[i]))
|
||||||
|
|
||||||
|
if len(prefix_indices[i]) == 0:
|
||||||
|
prefix_lens.append(0)
|
||||||
|
else:
|
||||||
|
prefix_lens.append(len(prefix_indices[i]))
|
||||||
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
: len(prefix_indices[i])
|
||||||
|
] = prefix_indices[i]
|
||||||
|
|
||||||
|
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
|
||||||
|
|
||||||
|
# Allocate memory
|
||||||
|
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
|
||||||
|
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
||||||
|
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||||
|
|
||||||
|
pt = 0
|
||||||
|
for i, req in enumerate(reqs):
|
||||||
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
|
||||||
|
] = out_cache_loc[pt : pt + extend_lens[i]]
|
||||||
|
pt += extend_lens[i]
|
||||||
|
|
||||||
|
# Set fields
|
||||||
|
with torch.device("cuda"):
|
||||||
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||||
|
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
|
||||||
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||||
|
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
|
||||||
|
|
||||||
|
self.pixel_values = [r.pixel_values for r in reqs]
|
||||||
|
self.image_sizes = [r.image_size for r in reqs]
|
||||||
|
self.image_offsets = [
|
||||||
|
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
|
||||||
|
]
|
||||||
|
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||||
|
self.extend_num_tokens = extend_num_tokens
|
||||||
|
self.out_cache_loc = out_cache_loc
|
||||||
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
|
|
||||||
|
self.batch_sampling_params(vocab_size, int_token_logit_bias)
|
||||||
|
|
||||||
def check_decode_mem(self):
|
def check_decode_mem(self):
|
||||||
bs = len(self.reqs)
|
bs = self.batch_size()
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -464,7 +479,6 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
retracted_reqs = []
|
retracted_reqs = []
|
||||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
|
||||||
while (
|
while (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool.available_size()
|
||||||
< len(sorted_indices) * global_config.retract_decode_steps
|
< len(sorted_indices) * global_config.retract_decode_steps
|
||||||
@@ -482,20 +496,20 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
if isinstance(self.tree_cache, ChunkCache):
|
if isinstance(self.tree_cache, ChunkCache):
|
||||||
# ChunkCache does not have eviction
|
# ChunkCache does not have eviction
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
req_pool_indices_cpu[idx]
|
: seq_lens_cpu[idx]
|
||||||
][: seq_lens_cpu[idx]]
|
]
|
||||||
self.token_to_kv_pool.free(token_indices)
|
self.token_to_kv_pool.free(token_indices)
|
||||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
del self.tree_cache.entries[req.rid]
|
del self.tree_cache.entries[req.rid]
|
||||||
else:
|
else:
|
||||||
# TODO: apply more fine-grained retraction
|
# TODO: apply more fine-grained retraction
|
||||||
last_uncached_pos = len(req.prefix_indices)
|
last_uncached_pos = len(req.prefix_indices)
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
req_pool_indices_cpu[idx]
|
last_uncached_pos : seq_lens_cpu[idx]
|
||||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
]
|
||||||
self.token_to_kv_pool.free(token_indices)
|
self.token_to_kv_pool.free(token_indices)
|
||||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
# release the last node
|
# release the last node
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
@@ -533,8 +547,6 @@ class ScheduleBatch:
|
|||||||
jump_forward_reqs = []
|
jump_forward_reqs = []
|
||||||
filter_indices = [i for i in range(len(self.reqs))]
|
filter_indices = [i for i in range(len(self.reqs))]
|
||||||
|
|
||||||
req_pool_indices_cpu = None
|
|
||||||
|
|
||||||
for i, req in enumerate(self.reqs):
|
for i, req in enumerate(self.reqs):
|
||||||
if req.jump_forward_map is not None:
|
if req.jump_forward_map is not None:
|
||||||
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
|
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
|
||||||
@@ -584,13 +596,11 @@ class ScheduleBatch:
|
|||||||
req.vid += 1
|
req.vid += 1
|
||||||
|
|
||||||
# insert the old request into tree_cache
|
# insert the old request into tree_cache
|
||||||
if req_pool_indices_cpu is None:
|
|
||||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
|
||||||
self.tree_cache.cache_req(
|
self.tree_cache.cache_req(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
token_ids=cur_all_ids,
|
token_ids=cur_all_ids,
|
||||||
last_uncached_pos=len(req.prefix_indices),
|
last_uncached_pos=len(req.prefix_indices),
|
||||||
req_pool_idx=req_pool_indices_cpu[i],
|
req_pool_idx=req.req_pool_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
# unlock the last node
|
# unlock the last node
|
||||||
@@ -626,14 +636,8 @@ class ScheduleBatch:
|
|||||||
self.prefix_lens = None
|
self.prefix_lens = None
|
||||||
|
|
||||||
# Alloc mem
|
# Alloc mem
|
||||||
bs = len(self.reqs)
|
bs = self.batch_size()
|
||||||
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
|
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||||
|
|
||||||
if self.out_cache_loc is None:
|
|
||||||
logger.error("Decode out of memory. Try to lower your batch size.")
|
|
||||||
if self.tree_cache is not None:
|
|
||||||
self.tree_cache.pretty_print()
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[
|
||||||
self.req_pool_indices, self.seq_lens - 1
|
self.req_pool_indices, self.seq_lens - 1
|
||||||
|
|||||||
@@ -200,7 +200,6 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
self.new_token_ratio = self.min_new_token_ratio
|
self.new_token_ratio = self.min_new_token_ratio
|
||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||||
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs):
|
def exposed_step(self, recv_reqs):
|
||||||
try:
|
try:
|
||||||
@@ -625,13 +624,12 @@ class ModelTpServer:
|
|||||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||||
|
|
||||||
def cache_filled_batch(self, batch: ScheduleBatch):
|
def cache_filled_batch(self, batch: ScheduleBatch):
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
token_ids=tuple(req.input_ids),
|
token_ids=tuple(req.input_ids),
|
||||||
last_uncached_pos=len(req.prefix_indices),
|
last_uncached_pos=len(req.prefix_indices),
|
||||||
req_pool_idx=req_pool_indices_cpu[i],
|
req_pool_idx=req.req_pool_idx,
|
||||||
del_in_memory_pool=False,
|
del_in_memory_pool=False,
|
||||||
old_last_node=req.last_node,
|
old_last_node=req.last_node,
|
||||||
)
|
)
|
||||||
@@ -639,7 +637,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
if req is self.current_inflight_req:
|
if req is self.current_inflight_req:
|
||||||
# inflight request would get a new req idx
|
# inflight request would get a new req idx
|
||||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
def forward_decode_batch(self, batch: ScheduleBatch):
|
def forward_decode_batch(self, batch: ScheduleBatch):
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
@@ -782,14 +780,13 @@ class ModelTpServer:
|
|||||||
# Remove finished reqs
|
# Remove finished reqs
|
||||||
if finished_indices:
|
if finished_indices:
|
||||||
# Update radix cache
|
# Update radix cache
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.tolist()
|
|
||||||
for i in finished_indices:
|
for i in finished_indices:
|
||||||
req = batch.reqs[i]
|
req = batch.reqs[i]
|
||||||
self.tree_cache.cache_req(
|
self.tree_cache.cache_req(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
|
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
|
||||||
last_uncached_pos=len(req.prefix_indices),
|
last_uncached_pos=len(req.prefix_indices),
|
||||||
req_pool_idx=req_pool_indices_cpu[i],
|
req_pool_idx=req.req_pool_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
"""Memory pool."""
|
"""Memory pool."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -27,34 +28,29 @@ class ReqToTokenPool:
|
|||||||
|
|
||||||
def __init__(self, size: int, max_context_len: int):
|
def __init__(self, size: int, max_context_len: int):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
|
self.free_slots = list(range(size))
|
||||||
self.req_to_token = torch.empty(
|
self.req_to_token = torch.empty(
|
||||||
(size, max_context_len), dtype=torch.int32, device="cuda"
|
(size, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.can_use_mem_size = size
|
self.can_use_mem_size = size
|
||||||
|
|
||||||
def alloc(self, need_size: int):
|
def alloc(self, need_size: int) -> List[int]:
|
||||||
if need_size > self.can_use_mem_size:
|
if need_size > len(self.free_slots):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
select_index = (
|
select_index = self.free_slots[:need_size]
|
||||||
torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
|
self.free_slots = self.free_slots[need_size:]
|
||||||
)
|
|
||||||
self.mem_state[select_index] = False
|
|
||||||
self.can_use_mem_size -= need_size
|
|
||||||
|
|
||||||
return select_index
|
return select_index
|
||||||
|
|
||||||
def free(self, free_index):
|
def free(self, free_index):
|
||||||
self.mem_state[free_index] = True
|
|
||||||
if isinstance(free_index, (int,)):
|
if isinstance(free_index, (int,)):
|
||||||
self.can_use_mem_size += 1
|
self.free_slots.append(free_index)
|
||||||
else:
|
else:
|
||||||
self.can_use_mem_size += free_index.shape[0]
|
self.free_slots.extend(free_index)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mem_state.fill_(True)
|
self.free_slots = list(range(self.size))
|
||||||
self.can_use_mem_size = len(self.mem_state)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTokenToKVPool:
|
class BaseTokenToKVPool:
|
||||||
|
|||||||
Reference in New Issue
Block a user