Memorypool chunked prefetch (#614)
This commit is contained in:
@@ -141,12 +141,5 @@ class RadixAttention(nn.Module):
|
||||
if input_metadata.out_cache_loc is not None:
|
||||
key_buffer[input_metadata.out_cache_loc] = cache_k
|
||||
value_buffer[input_metadata.out_cache_loc] = cache_v
|
||||
elif input_metadata.out_cache_cont_start is not None:
|
||||
key_buffer[
|
||||
input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
|
||||
] = cache_k
|
||||
value_buffer[
|
||||
input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
|
||||
] = cache_v
|
||||
else:
|
||||
raise RuntimeError()
|
||||
|
||||
@@ -104,8 +104,6 @@ class CudaGraphRunner:
|
||||
prefix_lens=None,
|
||||
position_ids_offsets=position_ids_offsets,
|
||||
out_cache_loc=out_cache_loc,
|
||||
out_cache_cont_start=None,
|
||||
out_cache_cont_end=None,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=0,
|
||||
skip_flashinfer_init=True,
|
||||
|
||||
@@ -275,8 +275,6 @@ class Batch:
|
||||
prefix_lens: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
out_cache_cont_start: int = None
|
||||
out_cache_cont_end: int = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
@@ -566,21 +564,12 @@ class Batch:
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
|
||||
if alloc_res is None:
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
|
||||
|
||||
if self.out_cache_loc is None:
|
||||
print("Decode out of memory. This should never happen.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
|
||||
self.out_cache_cont_start = None
|
||||
self.out_cache_cont_end = None
|
||||
else:
|
||||
self.out_cache_loc = alloc_res[0]
|
||||
self.out_cache_cont_start = alloc_res[1]
|
||||
self.out_cache_cont_end = alloc_res[2]
|
||||
if self.out_cache_loc is None:
|
||||
print("Decode out of memory. This should never happen.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens - 1
|
||||
@@ -594,7 +583,7 @@ class Batch:
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.prefix_lens = None
|
||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
@@ -622,7 +611,7 @@ class Batch:
|
||||
self.position_ids_offsets = torch.concat(
|
||||
[self.position_ids_offsets, other.position_ids_offsets]
|
||||
)
|
||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
@@ -729,8 +718,6 @@ class InputMetadata:
|
||||
|
||||
# Output location of the KV cache
|
||||
out_cache_loc: torch.Tensor = None
|
||||
out_cache_cont_start: int = None
|
||||
out_cache_cont_end: int = None
|
||||
|
||||
# Output options
|
||||
return_logprob: bool = False
|
||||
@@ -757,8 +744,6 @@ class InputMetadata:
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
out_cache_cont_start=None,
|
||||
out_cache_cont_end=None,
|
||||
top_logprobs_nums=None,
|
||||
return_logprob=False,
|
||||
skip_flashinfer_init=False,
|
||||
@@ -811,8 +796,6 @@ class InputMetadata:
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
out_cache_loc=out_cache_loc,
|
||||
out_cache_cont_start=out_cache_cont_start,
|
||||
out_cache_cont_end=out_cache_cont_end,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_start_loc=extend_start_loc,
|
||||
extend_no_prefix=extend_no_prefix,
|
||||
|
||||
@@ -245,8 +245,6 @@ class ModelRunner:
|
||||
prefix_lens=batch.prefix_lens,
|
||||
position_ids_offsets=batch.position_ids_offsets,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
out_cache_cont_start=batch.out_cache_cont_start,
|
||||
out_cache_cont_end=batch.out_cache_cont_end,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
)
|
||||
|
||||
@@ -50,6 +50,10 @@ class TokenToKVPool:
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
# Prefetch buffer
|
||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
||||
self.prefetch_chunk_size = 256
|
||||
|
||||
self.clear()
|
||||
|
||||
def get_key_buffer(self, layer_id):
|
||||
@@ -59,14 +63,29 @@ class TokenToKVPool:
|
||||
return self.kv_data[layer_id][:, 1]
|
||||
|
||||
def alloc(self, need_size):
|
||||
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
|
||||
if select_index.shape[0] < need_size:
|
||||
buffer_len = len(self.prefetch_buffer)
|
||||
if need_size <= buffer_len:
|
||||
select_index = self.prefetch_buffer[:need_size]
|
||||
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
||||
return select_index.to(torch.int32)
|
||||
|
||||
addition_size = need_size - buffer_len
|
||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
||||
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
|
||||
|
||||
if select_index.shape[0] < addition_size:
|
||||
return None
|
||||
|
||||
self.add_refs(select_index)
|
||||
return select_index.to(torch.int32)
|
||||
|
||||
self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
|
||||
ret_index = self.prefetch_buffer[:need_size]
|
||||
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
||||
|
||||
return ret_index.to(torch.int32)
|
||||
|
||||
def alloc_contiguous(self, need_size):
|
||||
# NOTE: This function is deprecated.
|
||||
empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
|
||||
if empty_index.shape[0] < need_size:
|
||||
return None
|
||||
@@ -89,7 +108,7 @@ class TokenToKVPool:
|
||||
return len(torch.nonzero(self.mem_state).squeeze(1))
|
||||
|
||||
def available_size(self):
|
||||
return torch.sum(self.mem_state == 0).item()
|
||||
return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
|
||||
|
||||
def add_refs(self, token_index: torch.Tensor):
|
||||
self.total_ref_ct += len(token_index)
|
||||
|
||||
Reference in New Issue
Block a user