Enable cuda graph by default (#612)

This commit is contained in:
Lianmin Zheng
2024-07-13 05:29:46 -07:00
committed by GitHub
parent 396a69240f
commit 665815969a
10 changed files with 331 additions and 84 deletions

View File

@@ -38,7 +38,10 @@ class ReqToTokenPool:
class TokenToKVPool:
def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
self.size = size
# mem_state is the reference counter.
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
self.total_ref_ct = 0
# [size, key/value, head_num, head_dim] for each layer
@@ -47,6 +50,8 @@ class TokenToKVPool:
for _ in range(layer_num)
]
self.clear()
def get_key_buffer(self, layer_id):
    """Return a view of the key slots (index 0 of the key/value axis) for layer `layer_id`."""
    layer_buf = self.kv_data[layer_id]
    return layer_buf[:, 0]
@@ -101,3 +106,6 @@ class TokenToKVPool:
def clear(self):
self.mem_state.fill_(0)
self.total_ref_ct = 0
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.add_refs(torch.tensor([0], dtype=torch.int32))