Support double sparsity (#1459)

Shuo Yang
2024-10-14 02:00:41 -07:00
committed by GitHub
parent 0c1e87964b
commit 061e546313
8 changed files with 1269 additions and 1 deletion


@@ -231,3 +231,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k


class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

        # [size, head_num, heavy_channel_num] for each layer
        self.label_buffer = [
            torch.empty(
                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
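
For context, a minimal usage sketch of the new pool (not part of the diff; all parameter values and the heavy-channel selection below are illustrative assumptions, the real channel indices come from the double-sparsity configuration):

import torch

# Allocate a pool for 2 layers and write one layer's KV cache together with
# the "label" tensor holding the heavy-channel slice of K.
pool = DoubleSparseTokenToKVPool(
    size=4096,             # number of token slots (illustrative)
    dtype=torch.float16,
    head_num=32,
    head_dim=128,
    layer_num=2,
    device="cuda",
    heavy_channel_num=16,  # channels kept per head for the sparse label
)

loc = torch.arange(8, device="cuda")  # slots to fill
cache_k = torch.randn(8, 32, 128, dtype=torch.float16, device="cuda")
cache_v = torch.randn(8, 32, 128, dtype=torch.float16, device="cuda")
# Assumed here: label = K restricted to the first heavy_channel_num channels.
cache_label = cache_k[..., :16].contiguous()

pool.set_kv_buffer(0, loc, cache_k, cache_v, cache_label)
k0, v0 = pool.get_kv_buffer(0)
labels0 = pool.get_label_buffer(0)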