Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)
This commit is contained in:
@@ -294,19 +294,6 @@ class MHATokenToKVPool(KVCache):
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
|
||||
self.data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
)
|
||||
self.data_strides = torch.tensor(
|
||||
[
|
||||
np.prod(x.shape[1:]) * x.dtype.itemsize
|
||||
for x in self.k_buffer + self.v_buffer
|
||||
],
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def _clear_buffers(self):
|
||||
del self.k_buffer
|
||||
del self.v_buffer
|
||||
@@ -464,16 +451,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||
|
||||
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
|
||||
copy_all_layer_kv_cache[(len(self.data_ptrs),)](
|
||||
self.data_ptrs,
|
||||
self.data_strides,
|
||||
tgt_loc,
|
||||
src_loc,
|
||||
len(tgt_loc),
|
||||
next_power_of_2(len(tgt_loc)),
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def set_mla_kv_buffer_kernel(
|
||||
@@ -764,41 +741,3 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def copy_all_layer_kv_cache(
|
||||
data_ptrs,
|
||||
strides,
|
||||
tgt_loc_ptr,
|
||||
src_loc_ptr,
|
||||
num_locs,
|
||||
num_locs_upper: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 128
|
||||
|
||||
bid = tl.program_id(0)
|
||||
stride = tl.load(strides + bid)
|
||||
|
||||
data_ptr = tl.load(data_ptrs + bid)
|
||||
data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
|
||||
|
||||
num_locs_offset = tl.arange(0, num_locs_upper)
|
||||
tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
|
||||
src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
|
||||
|
||||
# NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
|
||||
# because this copy is an inplace operation.
|
||||
|
||||
num_loop = tl.cdiv(stride, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
|
||||
value = tl.load(
|
||||
data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
|
||||
)
|
||||
tl.store(
|
||||
data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user