From 10143e1a5f2ee5826f7e566432d29e221d8c4af0 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sat, 13 Jul 2024 15:24:03 -0700
Subject: [PATCH] Memorypool chunked prefetch (#614)

---
 python/sglang/srt/layers/radix_attention.py  |  7 -------
 .../managers/controller/cuda_graph_runner.py |  2 --
 .../srt/managers/controller/infer_batch.py   | 31 +++++++------------------------
 .../srt/managers/controller/model_runner.py  |  2 --
 python/sglang/srt/memory_pool.py             | 27 +++++++++++++++++++++++----
 5 files changed, 30 insertions(+), 39 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 2964ae5b2..73e122a75 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -141,12 +141,5 @@ class RadixAttention(nn.Module):
         if input_metadata.out_cache_loc is not None:
             key_buffer[input_metadata.out_cache_loc] = cache_k
             value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
         else:
             raise RuntimeError()
diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py
index eee1bb81f..ad3225aa6 100644
--- a/python/sglang/srt/managers/controller/cuda_graph_runner.py
+++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -104,8 +104,6 @@ class CudaGraphRunner:
             prefix_lens=None,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=None,
-            out_cache_cont_end=None,
             return_logprob=False,
             top_logprobs_nums=0,
             skip_flashinfer_init=True,
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index d1bc60f9d..d89e9786e 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -275,8 +275,6 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None
 
     # For processing logprobs
     return_logprob: bool = False
@@ -566,21 +564,12 @@ class Batch:
 
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -594,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -622,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -729,8 +718,6 @@ class InputMetadata:
 
     # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None
 
     # Output options
     return_logprob: bool = False
@@ -757,8 +744,6 @@ class InputMetadata:
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
        skip_flashinfer_init=False,
@@ -811,8 +796,6 @@ class InputMetadata:
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
             extend_seq_lens=extend_seq_lens,
             extend_start_loc=extend_start_loc,
             extend_no_prefix=extend_no_prefix,
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index a24653661..315dd4d66 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -245,8 +245,6 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py
index 51b9beeb2..d586be433 100644
--- a/python/sglang/srt/memory_pool.py
+++ b/python/sglang/srt/memory_pool.py
@@ -50,6 +50,10 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]
 
+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 256
+
         self.clear()
 
     def get_key_buffer(self, layer_id):
@@ -59,14 +63,29 @@ class TokenToKVPool:
         return self.kv_data[layer_id][:, 1]
 
     def alloc(self, need_size):
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if select_index.shape[0] < need_size:
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index.to(torch.int32)
+
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
+
+        if select_index.shape[0] < addition_size:
             return None
 
         self.add_refs(select_index)
-        return select_index.to(torch.int32)
+
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+        return ret_index.to(torch.int32)
 
     def alloc_contiguous(self, need_size):
+        # NOTE: This function is deprecated.
         empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
         if empty_index.shape[0] < need_size:
             return None
@@ -89,7 +108,7 @@ class TokenToKVPool:
         return len(torch.nonzero(self.mem_state).squeeze(1))
 
     def available_size(self):
-        return torch.sum(self.mem_state == 0).item()
+        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
 
     def add_refs(self, token_index: torch.Tensor):
         self.total_ref_ct += len(token_index)
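Why this helps: before the patch, every decode step first tried `alloc_contiguous(bs)` and, on failure, fell back to `alloc(bs)`, each running a full `torch.nonzero(self.mem_state == 0)` scan of the KV pool. After the patch, an `alloc` that misses `prefetch_buffer` triggers one scan that reserves at least `prefetch_chunk_size` (256) slots, and subsequent calls are served straight from the buffer, so for small batch sizes the scan runs roughly once per 256 allocated slots. Note that `available_size()` now also counts slots parked in `prefetch_buffer`, since they are reserved but not yet handed out. Below is a minimal, self-contained sketch of the same strategy, not the patched class itself: `ChunkedAllocator` is an illustrative stand-in for `TokenToKVPool`, it uses a CPU tensor instead of CUDA, and `add_refs` is collapsed into a free/used bitmap.

import torch


class ChunkedAllocator:
    # Illustrative stand-in for TokenToKVPool's chunked-prefetch allocation.
    def __init__(self, size, chunk_size=256):
        # 0 = free, 1 = used; plays the role of TokenToKVPool.mem_state.
        self.mem_state = torch.zeros(size, dtype=torch.int8)
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)
        self.chunk_size = chunk_size

    def alloc(self, need_size):
        # Fast path: serve the request entirely from prefetched indices,
        # with no scan of mem_state.
        if need_size <= len(self.prefetch_buffer):
            out = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return out

        # Slow path: one torch.nonzero scan reserves at least a full chunk,
        # amortizing the scan cost over many future alloc() calls.
        addition = need_size - len(self.prefetch_buffer)
        grab = max(addition, self.chunk_size)
        free = torch.nonzero(self.mem_state == 0).squeeze(1)[:grab].to(torch.int32)
        if free.shape[0] < addition:
            return None  # genuinely out of memory
        self.mem_state[free.long()] = 1  # stands in for add_refs()
        self.prefetch_buffer = torch.cat((self.prefetch_buffer, free))
        out = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return out


if __name__ == "__main__":
    allocator = ChunkedAllocator(size=1024)
    print(allocator.alloc(4))  # slow path: scans once, prefetches 256 slots
    print(allocator.alloc(4))  # fast path: served from the buffer, no scan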