diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index 46010ccf7..647be2810 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -20,8 +20,9 @@ class ReqToTokenPool: return None select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size] - self.mem_state[select_index] = 0 + self.mem_state[select_index] = False self.can_use_mem_size -= need_size + return select_index.to(torch.int32) def free(self, free_index): @@ -29,22 +30,23 @@ class ReqToTokenPool: self.can_use_mem_size += 1 else: self.can_use_mem_size += free_index.shape[0] - self.mem_state[free_index] = 1 + + self.mem_state[free_index] = True def clear(self): - self.mem_state.fill_(1) + self.mem_state.fill_(True) self.can_use_mem_size = len(self.mem_state) class TokenToKVPool: def __init__(self, size, dtype, head_num, head_dim, layer_num): self.size = size + # This can be promised: # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0) # We also add one slot. This slot is used for writing dummy output from padded tokens. - self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda") - self.total_size = self.size - self.total_alloc = 0 + self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") + self.can_use_mem_size = self.size # [size, key/value, head_num, head_dim] for each layer self.kv_data = [ @@ -73,9 +75,8 @@ class TokenToKVPool: addition_size = need_size - buffer_len alloc_size = max(addition_size, self.prefetch_chunk_size) - select_index = ( - torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32) - ) + select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size] + select_index = select_index.to(torch.int32) if select_index.shape[0] < addition_size: return None @@ -88,43 +89,20 @@ class TokenToKVPool: return ret_index - def alloc_contiguous(self, need_size): - # NOTE: This function is deprecated. - empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size] - if empty_index.shape[0] < need_size: - return None - empty_size = len(empty_index) - loc_sum = ( - empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)] - ) - can_used_loc = empty_index[: empty_size - (need_size - 1)][ - loc_sum == need_size - 1 - ] - if can_used_loc.shape[0] == 0: - return None - - start_loc = can_used_loc[0].item() - select_index = torch.arange(start_loc, start_loc + need_size, device="cuda") - self.add_refs(select_index) - return select_index.to(torch.int32), start_loc, start_loc + need_size - - def used_size(self): - return self.total_alloc - def available_size(self): - return self.total_size - self.total_alloc + len(self.prefetch_buffer) + return self.can_use_mem_size + len(self.prefetch_buffer) def add_refs(self, token_index: torch.Tensor): - self.total_alloc += len(token_index) - self.mem_state[token_index] ^= True + self.can_use_mem_size -= len(token_index) + self.mem_state[token_index] = False def dec_refs(self, token_index: torch.Tensor): - self.total_alloc -= len(token_index) - self.mem_state[token_index] ^= True + self.can_use_mem_size += len(token_index) + self.mem_state[token_index] = True def clear(self): - self.mem_state.fill_(0) - self.total_alloc = 0 + self.mem_state.fill_(True) + self.can_use_mem_size = self.size # We also add one slot. This slot is used for writing dummy output from padded tokens. - self.mem_state[0] = True + self.mem_state[0] = False