Make req_pool_indices on CPU (#960)

This commit is contained in:
Liangsheng Yin
2024-08-07 01:41:25 -07:00
committed by GitHub
parent 05abd1261c
commit 7fa54a1ab3
4 changed files with 110 additions and 114 deletions

View File

@@ -16,6 +16,7 @@ limitations under the License.
"""Memory pool."""
import logging
from typing import List
import torch
@@ -27,34 +28,29 @@ class ReqToTokenPool:
def __init__(self, size: int, max_context_len: int):
self.size = size
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
self.free_slots = list(range(size))
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device="cuda"
)
self.can_use_mem_size = size
def alloc(self, need_size: int):
if need_size > self.can_use_mem_size:
def alloc(self, need_size: int) -> List[int]:
if need_size > len(self.free_slots):
return None
select_index = (
torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
)
self.mem_state[select_index] = False
self.can_use_mem_size -= need_size
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index
def free(self, free_index):
self.mem_state[free_index] = True
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
self.free_slots.append(free_index)
else:
self.can_use_mem_size += free_index.shape[0]
self.free_slots.extend(free_index)
def clear(self):
self.mem_state.fill_(True)
self.can_use_mem_size = len(self.mem_state)
self.free_slots = list(range(self.size))
class BaseTokenToKVPool: