Files
sglang/python/sglang/srt/memory_pool.py
2024-07-13 15:24:03 -07:00

130 lines
4.3 KiB
Python

"""Memory pools.

Defines ``ReqToTokenPool`` (an allocator of request slots, each owning one row
of a request-to-token table) and ``TokenToKVPool`` (a reference-counted
allocator of per-token KV-cache slots).
"""

import logging

import torch

logger = logging.getLogger(__name__)
class ReqToTokenPool:
    """Allocator of request slots, each owning one row of ``req_to_token``.

    ``mem_state[i]`` is True while slot ``i`` is free, and
    ``can_use_mem_size`` tracks how many slots are free. ``req_to_token`` is a
    ``(size, max_context_len)`` int32 table whose row ``i`` belongs to slot
    ``i``. It is not written here; presumably callers fill it with the token
    locations of the request (TODO confirm against callers).

    Args:
        size: Number of request slots.
        max_context_len: Width of each ``req_to_token`` row.
        device: Device for the pool tensors (defaults to ``"cuda"``).
    """

    def __init__(self, size, max_context_len, device="cuda"):
        self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
        self.can_use_mem_size = size
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )

    def alloc(self, need_size):
        """Reserve ``need_size`` free slots.

        Returns:
            The lowest-numbered free slot indices as an int32 tensor, or None
            if fewer than ``need_size`` slots are free (state is unchanged).
        """
        if need_size > self.can_use_mem_size:
            return None

        # nonzero() lists free slots in ascending order.
        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
        self.mem_state[select_index] = 0
        self.can_use_mem_size -= need_size
        return select_index.to(torch.int32)

    def free(self, free_index):
        """Return slots to the pool.

        Args:
            free_index: A single slot index (Python/NumPy int or 0-d tensor),
                or a list/tensor of slot indices.

        NOTE: freeing a slot that is already free is not detected and will
        inflate ``can_use_mem_size``.
        """
        # numel() is 1 for scalars and the length for 1-D inputs. Relying on
        # ``.shape[0]`` would fail for 0-d tensors, NumPy scalars and lists.
        self.can_use_mem_size += torch.as_tensor(free_index).numel()
        self.mem_state[free_index] = 1

    def clear(self):
        """Mark every slot as free."""
        self.mem_state.fill_(1)
        self.can_use_mem_size = len(self.mem_state)
class TokenToKVPool:
    """Reference-counted allocator of per-token KV-cache slots.

    ``mem_state[i]`` is the reference count of slot ``i``; a slot is free when
    its count is 0. ``clear()`` permanently references slot 0, which is used
    for writing dummy output from padded tokens, so the ``size`` allocatable
    slots are numbered ``1..size``. ``kv_data`` holds one buffer per layer,
    indexed by these slot numbers.

    Args:
        size: Number of allocatable token slots.
        dtype: Element dtype of the KV buffers.
        head_num: Number of heads stored per token.
        head_dim: Size of each head.
        layer_num: Number of layers (one KV buffer per layer).
        device: Device for all pool tensors (defaults to ``"cuda"``).
    """

    def __init__(self, size, dtype, head_num, head_dim, layer_num, device="cuda"):
        self.size = size
        self.device = device
        # mem_state is the reference counter, one entry per slot.
        # It has one extra slot (index 0, reserved in clear()). This slot is
        # used for writing dummy output from padded tokens.
        self.mem_state = torch.zeros(
            (self.size + 1,), dtype=torch.int16, device=device
        )
        self.total_ref_ct = 0

        # [size + 1, key/value, head_num, head_dim] for each layer.
        # The buffers must cover every mem_state index (0..size); with only
        # `size` rows, the KV of slot `size` would be written out of bounds.
        self.kv_data = [
            torch.empty(
                (size + 1, 2, head_num, head_dim), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

        # Prefetch buffer: slots already referenced by alloc() but not yet
        # handed out. This lets alloc() skip the mem_state scan while the
        # buffer can satisfy a request.
        self.prefetch_buffer = torch.empty(0, device=device, dtype=torch.int32)
        self.prefetch_chunk_size = 256

        self.clear()

    def get_key_buffer(self, layer_id):
        """Return the key half of layer ``layer_id``'s KV buffer (a view)."""
        return self.kv_data[layer_id][:, 0]

    def get_value_buffer(self, layer_id):
        """Return the value half of layer ``layer_id``'s KV buffer (a view)."""
        return self.kv_data[layer_id][:, 1]

    def alloc(self, need_size):
        """Hand out ``need_size`` slot indices.

        Slots come from the prefetch buffer first. When it runs short, one
        scan of ``mem_state`` reserves up to
        ``max(shortfall, prefetch_chunk_size)`` free slots, and any surplus
        stays in the buffer for later calls.

        Returns:
            An int32 tensor of slot indices, or None (with no state changed)
            if not enough free slots exist.
        """
        buffer_len = len(self.prefetch_buffer)
        if need_size <= buffer_len:
            select_index = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return select_index.to(torch.int32)

        addition_size = need_size - buffer_len
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
        if select_index.shape[0] < addition_size:
            return None

        # Reserve the scanned slots now; they stay referenced while buffered.
        self.add_refs(select_index)
        # nonzero() yields int64; convert so the buffer keeps a stable dtype.
        self.prefetch_buffer = torch.cat(
            (self.prefetch_buffer, select_index.to(torch.int32))
        )
        ret_index = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return ret_index.to(torch.int32)

    def alloc_contiguous(self, need_size):
        """Reserve the first run of ``need_size`` consecutive free slots.

        NOTE: This function is deprecated.

        Returns:
            ``(indices, start, end)``, where ``indices`` is an int32 tensor
            covering ``[start, end)``. Returns None if ``need_size <= 0`` or no
            such run exists.
        """
        if need_size <= 0:
            return None
        # Scan all free slots, not just the first `need_size` of them;
        # otherwise a run after a gap in the free list would never be found.
        empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)
        if empty_index.shape[0] < need_size:
            return None
        empty_size = len(empty_index)
        num_starts = empty_size - (need_size - 1)
        # empty_index is strictly increasing. Slots empty_index[i] ..
        # empty_index[i + need_size - 1] are consecutive iff their difference
        # is exactly need_size - 1.
        span = empty_index[need_size - 1 :] - empty_index[:num_starts]
        can_used_loc = empty_index[:num_starts][span == need_size - 1]
        if can_used_loc.shape[0] == 0:
            return None
        start_loc = can_used_loc[0].item()
        select_index = torch.arange(
            start_loc, start_loc + need_size, device=self.device
        )
        self.add_refs(select_index)
        return select_index.to(torch.int32), start_loc, start_loc + need_size

    def used_size(self):
        """Return the number of slots with a nonzero reference count.

        This includes the reserved slot 0 and slots held in the prefetch
        buffer.
        """
        return int(torch.count_nonzero(self.mem_state).item())

    def available_size(self):
        """Return free slots plus prefetched-but-unassigned slots."""
        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)

    def add_refs(self, token_index: torch.Tensor):
        """Increment the reference count of each slot in ``token_index``."""
        # NOTE(review): with advanced indexing, `+=` updates each distinct
        # index once, so duplicates in token_index increment only once while
        # total_ref_ct counts every entry. Presumably callers never pass
        # duplicates; confirm.
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1

    def dec_refs(self, token_index: torch.Tensor):
        """Decrement the reference count of each slot in ``token_index``.

        Returns:
            A 0-d tensor counting the entries of ``token_index`` whose
            reference count is now 0 (i.e. slots that became free).
        """
        # Same duplicate-index caveat as add_refs.
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1
        num_freed = torch.sum(self.mem_state[token_index] == 0)
        return num_freed

    def clear(self):
        """Release every slot, drop the prefetch buffer, and re-reserve slot 0."""
        self.mem_state.fill_(0)
        self.total_ref_ct = 0
        # Prefetched slots lost their references in fill_(0). Keeping them
        # buffered would let alloc() hand out the same slot twice.
        self.prefetch_buffer = torch.empty(
            0, device=self.device, dtype=torch.int32
        )

        # Reserve slot 0: it is used for writing dummy output from padded tokens.
        self.add_refs(torch.tensor([0], device=self.device))