Improve benchmark scripts (#615)

This commit is contained in:
Lianmin Zheng
2024-07-13 15:59:04 -07:00
committed by GitHub
parent 10143e1a5f
commit 0feca02dd9
2 changed files with 6 additions and 5 deletions

View File

@@ -52,7 +52,7 @@ class TokenToKVPool:
# Prefetch buffer
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-self.prefetch_chunk_size = 256
+self.prefetch_chunk_size = 512
self.clear()
@@ -67,11 +67,11 @@ class TokenToKVPool:
if need_size <= buffer_len:
select_index = self.prefetch_buffer[:need_size]
self.prefetch_buffer = self.prefetch_buffer[need_size:]
-return select_index.to(torch.int32)
+return select_index
addition_size = need_size - buffer_len
alloc_size = max(addition_size, self.prefetch_chunk_size)
-select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
+select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
if select_index.shape[0] < addition_size:
return None
@@ -82,7 +82,7 @@ class TokenToKVPool:
ret_index = self.prefetch_buffer[:need_size]
self.prefetch_buffer = self.prefetch_buffer[need_size:]
-return ret_index.to(torch.int32)
+return ret_index
def alloc_contiguous(self, need_size):
# NOTE: This function is deprecated.