Make req_pool_indices on CPU (#960)
This commit is contained in:
@@ -16,6 +16,7 @@ limitations under the License.
|
||||
"""Memory pool."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,34 +28,29 @@ class ReqToTokenPool:
|
||||
|
||||
def __init__(self, size: int, max_context_len: int):
|
||||
self.size = size
|
||||
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
|
||||
self.free_slots = list(range(size))
|
||||
self.req_to_token = torch.empty(
|
||||
(size, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.can_use_mem_size = size
|
||||
|
||||
def alloc(self, need_size: int):
|
||||
if need_size > self.can_use_mem_size:
|
||||
def alloc(self, need_size: int) -> List[int]:
|
||||
if need_size > len(self.free_slots):
|
||||
return None
|
||||
|
||||
select_index = (
|
||||
torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
|
||||
)
|
||||
self.mem_state[select_index] = False
|
||||
self.can_use_mem_size -= need_size
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
|
||||
return select_index
|
||||
|
||||
def free(self, free_index):
|
||||
self.mem_state[free_index] = True
|
||||
if isinstance(free_index, (int,)):
|
||||
self.can_use_mem_size += 1
|
||||
self.free_slots.append(free_index)
|
||||
else:
|
||||
self.can_use_mem_size += free_index.shape[0]
|
||||
self.free_slots.extend(free_index)
|
||||
|
||||
def clear(self):
|
||||
self.mem_state.fill_(True)
|
||||
self.can_use_mem_size = len(self.mem_state)
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
|
||||
class BaseTokenToKVPool:
|
||||
|
||||
Reference in New Issue
Block a user