Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)

This commit is contained in:
Lianmin Zheng
2025-06-15 02:48:00 -07:00
committed by GitHub
parent 5f1ab32717
commit fff10809bf
7 changed files with 150 additions and 647 deletions

View File

@@ -294,19 +294,6 @@ class MHATokenToKVPool(KVCache):
for _ in range(self.layer_num)
]
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
dtype=torch.uint64,
device=self.device,
)
self.data_strides = torch.tensor(
[
np.prod(x.shape[1:]) * x.dtype.itemsize
for x in self.k_buffer + self.v_buffer
],
device=self.device,
)
def _clear_buffers(self):
del self.k_buffer
del self.v_buffer
@@ -464,16 +451,6 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
copy_all_layer_kv_cache[(len(self.data_ptrs),)](
self.data_ptrs,
self.data_strides,
tgt_loc,
src_loc,
len(tgt_loc),
next_power_of_2(len(tgt_loc)),
)
@triton.jit
def set_mla_kv_buffer_kernel(
@@ -764,41 +741,3 @@ class DoubleSparseTokenToKVPool(KVCache):
def transfer_per_layer(self, indices, flat_data, layer_id):
pass
@triton.jit
def copy_all_layer_kv_cache(
data_ptrs,
strides,
tgt_loc_ptr,
src_loc_ptr,
num_locs,
num_locs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
bid = tl.program_id(0)
stride = tl.load(strides + bid)
data_ptr = tl.load(data_ptrs + bid)
data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
num_locs_offset = tl.arange(0, num_locs_upper)
tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
# NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
# because this copy is an inplace operation.
num_loop = tl.cdiv(stride, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
value = tl.load(
data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
)
tl.store(
data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
value,
mask=mask,
)