Improve benchmark scripts (#615)
This commit is contained in:
@@ -52,7 +52,7 @@ class TokenToKVPool:
|
||||
|
||||
# Prefetch buffer
|
||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
||||
- self.prefetch_chunk_size = 256
|
||||
+ self.prefetch_chunk_size = 512
|
||||
|
||||
self.clear()
|
||||
|
||||
@@ -67,11 +67,11 @@ class TokenToKVPool:
|
||||
if need_size <= buffer_len:
|
||||
select_index = self.prefetch_buffer[:need_size]
|
||||
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
||||
- return select_index.to(torch.int32)
|
||||
+ return select_index
|
||||
|
||||
addition_size = need_size - buffer_len
|
||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
||||
- select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
|
||||
+ select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
|
||||
|
||||
if select_index.shape[0] < addition_size:
|
||||
return None
|
||||
@@ -82,7 +82,7 @@ class TokenToKVPool:
|
||||
ret_index = self.prefetch_buffer[:need_size]
|
||||
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
||||
|
||||
- return ret_index.to(torch.int32)
|
||||
+ return ret_index
|
||||
|
||||
def alloc_contiguous(self, need_size):
|
||||
# NOTE: This function is deprecated.
|
||||
|
||||
Reference in New Issue
Block a user