Support double sparsity (#1459)
This commit is contained in:
@@ -231,3 +231,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
else:
|
||||
self.kv_buffer[layer_id][loc] = cache_k
|
||||
|
||||
|
||||
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
layer_num: int,
|
||||
device: str,
|
||||
heavy_channel_num: int,
|
||||
):
|
||||
super().__init__(size, dtype, device)
|
||||
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
# [size, head_num, heavy_channel_num] for each layer
|
||||
self.label_buffer = [
|
||||
torch.empty(
|
||||
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
return self.v_buffer[layer_id]
|
||||
|
||||
def get_label_buffer(self, layer_id: int):
|
||||
return self.label_buffer[layer_id]
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
cache_label: torch.Tensor,
|
||||
):
|
||||
# NOTE(Andy): ignore the dtype check
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
self.label_buffer[layer_id][loc] = cache_label
|
||||
|
||||
Reference in New Issue
Block a user