Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)
This commit is contained in:
@@ -1049,13 +1049,14 @@ class FlashInferMultiStepDraftBackend:
|
||||
kv_indices_buffer,
|
||||
self.kv_indptr,
|
||||
forward_batch.positions,
|
||||
num_seqs,
|
||||
self.topk,
|
||||
self.pool_len,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
assert forward_batch.spec_info is not None
|
||||
|
||||
@@ -789,7 +789,6 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
|
||||
# Cached variables for generate_draft_decode_kv_indices
|
||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||
self.page_size = model_runner.server_args.page_size
|
||||
|
||||
def common_template(
|
||||
self,
|
||||
@@ -810,13 +809,14 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
kv_indices_buffer,
|
||||
self.kv_indptr,
|
||||
forward_batch.positions,
|
||||
num_seqs,
|
||||
self.topk,
|
||||
self.pool_len,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
assert forward_batch.spec_info is not None
|
||||
|
||||
@@ -784,13 +784,14 @@ class TritonMultiStepDraftBackend:
|
||||
kv_indices_buffer,
|
||||
self.kv_indptr,
|
||||
forward_batch.positions,
|
||||
num_seqs,
|
||||
self.topk,
|
||||
self.pool_len,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
|
||||
@@ -294,19 +294,6 @@ class MHATokenToKVPool(KVCache):
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
|
||||
self.data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
)
|
||||
self.data_strides = torch.tensor(
|
||||
[
|
||||
np.prod(x.shape[1:]) * x.dtype.itemsize
|
||||
for x in self.k_buffer + self.v_buffer
|
||||
],
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def _clear_buffers(self):
|
||||
del self.k_buffer
|
||||
del self.v_buffer
|
||||
@@ -464,16 +451,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||
|
||||
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
|
||||
copy_all_layer_kv_cache[(len(self.data_ptrs),)](
|
||||
self.data_ptrs,
|
||||
self.data_strides,
|
||||
tgt_loc,
|
||||
src_loc,
|
||||
len(tgt_loc),
|
||||
next_power_of_2(len(tgt_loc)),
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def set_mla_kv_buffer_kernel(
|
||||
@@ -764,41 +741,3 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
|
||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def copy_all_layer_kv_cache(
|
||||
data_ptrs,
|
||||
strides,
|
||||
tgt_loc_ptr,
|
||||
src_loc_ptr,
|
||||
num_locs,
|
||||
num_locs_upper: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 128
|
||||
|
||||
bid = tl.program_id(0)
|
||||
stride = tl.load(strides + bid)
|
||||
|
||||
data_ptr = tl.load(data_ptrs + bid)
|
||||
data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
|
||||
|
||||
num_locs_offset = tl.arange(0, num_locs_upper)
|
||||
tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
|
||||
src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
|
||||
|
||||
# NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
|
||||
# because this copy is an inplace operation.
|
||||
|
||||
num_loop = tl.cdiv(stride, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
|
||||
value = tl.load(
|
||||
data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
|
||||
)
|
||||
tl.store(
|
||||
data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
@@ -67,6 +67,8 @@ class EagleDraftInput:
|
||||
kv_indptr: torch.Tensor = None
|
||||
kv_indices: torch.Tensor = None
|
||||
|
||||
all_padding_lens: Optional[torch.Tensor] = None
|
||||
|
||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||
# Prefill only generate 1 token.
|
||||
assert len(self.verified_id) == len(batch.seq_lens)
|
||||
@@ -91,7 +93,6 @@ class EagleDraftInput:
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
batch.return_logprob = False
|
||||
batch.return_hidden_states = False
|
||||
|
||||
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
self.accept_length.add_(1)
|
||||
@@ -115,8 +116,10 @@ class EagleDraftInput:
|
||||
req_to_token: torch.Tensor,
|
||||
):
|
||||
bs = self.accept_length.numel()
|
||||
|
||||
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||
|
||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
|
||||
@@ -136,6 +139,7 @@ class EagleDraftInput:
|
||||
kv_indices,
|
||||
req_to_token.size(1),
|
||||
)
|
||||
|
||||
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||
|
||||
def filter_batch(self, new_indices: torch.Tensor):
|
||||
@@ -266,7 +270,7 @@ class EagleVerifyInput:
|
||||
logits_output: torch.Tensor,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
page_size: int,
|
||||
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
||||
vocab_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Verify and find accepted tokens based on logits output and batch
|
||||
@@ -290,14 +294,6 @@ class EagleVerifyInput:
|
||||
)
|
||||
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
||||
|
||||
# Apply the custom logit processors if registered in the sampling info.
|
||||
if sampling_info.has_custom_logit_processor:
|
||||
apply_custom_logit_processor(
|
||||
logits_output.next_token_logits,
|
||||
sampling_info,
|
||||
num_tokens_in_batch=self.draft_token_num,
|
||||
)
|
||||
|
||||
# Apply penalty
|
||||
if sampling_info.penalizer_orchestrator.is_required:
|
||||
# This is a relaxed version of penalties for speculative decoding.
|
||||
@@ -359,13 +355,7 @@ class EagleVerifyInput:
|
||||
draft_probs = torch.zeros(
|
||||
target_probs.shape, dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
# coins for rejection sampling
|
||||
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
|
||||
# coins for final sampling
|
||||
coins_for_final_sampling = torch.rand(
|
||||
(bs,), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
tree_speculative_sampling_target_only(
|
||||
predicts=predict, # mutable
|
||||
accept_index=accept_index, # mutable
|
||||
@@ -375,7 +365,6 @@ class EagleVerifyInput:
|
||||
retrive_next_token=self.retrive_next_token.to(torch.int32),
|
||||
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
|
||||
uniform_samples=coins,
|
||||
# uniform_samples_for_final_sampling=coins_for_final_sampling,
|
||||
target_probs=target_probs,
|
||||
draft_probs=draft_probs,
|
||||
threshold_single=global_server_args_dict[
|
||||
@@ -398,8 +387,8 @@ class EagleVerifyInput:
|
||||
spec_steps=self.spec_steps,
|
||||
)
|
||||
|
||||
new_accept_index = []
|
||||
unfinished_index = []
|
||||
unfinished_accept_index = []
|
||||
accept_index_cpu = accept_index.tolist()
|
||||
predict_cpu = predict.tolist()
|
||||
has_finished = False
|
||||
@@ -407,10 +396,12 @@ class EagleVerifyInput:
|
||||
# Iterate every accepted token and check if req has finished after append the token
|
||||
# should be checked BEFORE free kv cache slots
|
||||
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
||||
new_accept_index_ = []
|
||||
for j, idx in enumerate(accept_index_row):
|
||||
if idx == -1:
|
||||
break
|
||||
id = predict_cpu[idx]
|
||||
# if not found_finished:
|
||||
req.output_ids.append(id)
|
||||
req.check_finished()
|
||||
if req.finished():
|
||||
@@ -419,6 +410,8 @@ class EagleVerifyInput:
|
||||
accept_index[i, j + 1 :] = -1
|
||||
break
|
||||
else:
|
||||
new_accept_index_.append(idx)
|
||||
# update grammar state
|
||||
if req.grammar is not None:
|
||||
try:
|
||||
req.grammar.accept_token(id)
|
||||
@@ -428,104 +421,50 @@ class EagleVerifyInput:
|
||||
)
|
||||
raise e
|
||||
if not req.finished():
|
||||
new_accept_index.extend(new_accept_index_)
|
||||
unfinished_index.append(i)
|
||||
if idx == -1:
|
||||
unfinished_accept_index.append(accept_index[i, :j])
|
||||
else:
|
||||
unfinished_accept_index.append(accept_index[i])
|
||||
req.spec_verify_ct += 1
|
||||
|
||||
if has_finished:
|
||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||
|
||||
# Free the KV cache for unaccepted tokens
|
||||
# TODO: fuse them
|
||||
accept_index = accept_index[accept_index != -1]
|
||||
verified_id = predict[accept_index]
|
||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||
evict_mask[accept_index] = False
|
||||
|
||||
if page_size == 1:
|
||||
# TODO: boolean array index leads to a device sync. Remove it.
|
||||
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
||||
else:
|
||||
if self.topk == 1:
|
||||
# Only evict full empty page. Do not evict partial empty page
|
||||
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
||||
batch.seq_lens,
|
||||
evict_mask,
|
||||
page_size,
|
||||
self.draft_token_num,
|
||||
next_power_of_2(self.draft_token_num),
|
||||
)
|
||||
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
||||
else:
|
||||
# Shift the accepted tokens to the beginning.
|
||||
# Only evict the last part
|
||||
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
|
||||
batch.seq_lens,
|
||||
batch.out_cache_loc,
|
||||
accept_index,
|
||||
accept_length,
|
||||
self.draft_token_num,
|
||||
page_size,
|
||||
)
|
||||
to_free_slots = torch.empty(
|
||||
(to_free_num_slots.sum().item(),),
|
||||
dtype=torch.int64,
|
||||
device=to_free_num_slots.device,
|
||||
)
|
||||
if page_size != 1:
|
||||
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
||||
batch.seq_lens,
|
||||
evict_mask,
|
||||
page_size,
|
||||
self.draft_token_num,
|
||||
next_power_of_2(self.draft_token_num),
|
||||
)
|
||||
|
||||
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
|
||||
# accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
|
||||
# tgt_cache_loc: [0 1 , 3 4 , 6 ]
|
||||
# to_free_slots: [ 2, 5, 7 8]
|
||||
# to_free_slots also needs to be page-aligned without the first partial page
|
||||
#
|
||||
# split each row of out_cache_loc into two parts.
|
||||
# 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
|
||||
# 2. the second part goes to to_free_slots.
|
||||
get_target_cache_loc[(bs,)](
|
||||
tgt_cache_loc,
|
||||
to_free_slots,
|
||||
accept_length,
|
||||
to_free_num_slots,
|
||||
batch.out_cache_loc,
|
||||
self.draft_token_num,
|
||||
next_power_of_2(self.draft_token_num),
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
|
||||
# Free the kv cache
|
||||
token_to_kv_pool_allocator.free(to_free_slots)
|
||||
|
||||
# Copy the kv cache
|
||||
batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
||||
tgt_cache_loc, src_cache_loc
|
||||
)
|
||||
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
||||
|
||||
# Construct EagleVerifyOutput
|
||||
if not has_finished:
|
||||
if page_size == 1 or self.topk == 1:
|
||||
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
||||
assign_req_to_token_pool[(bs,)](
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.seq_lens,
|
||||
batch.seq_lens + accept_length + 1,
|
||||
batch.out_cache_loc,
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
else:
|
||||
batch.out_cache_loc = tgt_cache_loc
|
||||
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
||||
assign_req_to_token_pool[(bs,)](
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.seq_lens,
|
||||
batch.seq_lens + accept_length + 1,
|
||||
batch.out_cache_loc,
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
accept_length_cpu = accept_length.tolist()
|
||||
|
||||
draft_input = EagleDraftInput()
|
||||
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
|
||||
draft_input.verified_id = verified_id
|
||||
draft_input.accept_length = accept_length
|
||||
draft_input.accept_length_cpu = accept_length.tolist()
|
||||
draft_input.accept_length_cpu = accept_length_cpu
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
||||
|
||||
@@ -533,66 +472,47 @@ class EagleVerifyInput:
|
||||
draft_input=draft_input,
|
||||
logits_output=logits_output,
|
||||
verified_id=verified_id,
|
||||
accept_length_per_req_cpu=draft_input.accept_length_cpu,
|
||||
accept_length_per_req_cpu=accept_length_cpu,
|
||||
accepted_indices=accept_index,
|
||||
)
|
||||
else:
|
||||
if page_size == 1 or self.topk == 1:
|
||||
assign_req_to_token_pool[(bs,)](
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.seq_lens,
|
||||
batch.seq_lens + accept_length + 1,
|
||||
batch.out_cache_loc[accept_index],
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
|
||||
assign_req_to_token_pool[(bs,)](
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.seq_lens,
|
||||
batch.seq_lens + accept_length + 1,
|
||||
batch.out_cache_loc[accept_index],
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
accept_length_cpu = accept_length.tolist()
|
||||
|
||||
draft_input = EagleDraftInput()
|
||||
if len(unfinished_accept_index) > 0:
|
||||
unfinished_accept_index = torch.cat(unfinished_accept_index)
|
||||
unfinished_index_device = torch.tensor(
|
||||
unfinished_index, dtype=torch.int64, device=predict.device
|
||||
)
|
||||
draft_input_accept_length_cpu = [
|
||||
if len(new_accept_index) > 0:
|
||||
new_accept_index = torch.tensor(new_accept_index, device="cuda")
|
||||
unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
|
||||
draft_input.hidden_states = batch.spec_info.hidden_states[
|
||||
new_accept_index
|
||||
]
|
||||
draft_input.verified_id = predict[new_accept_index]
|
||||
draft_input.accept_length_cpu = [
|
||||
accept_length_cpu[i] for i in unfinished_index
|
||||
]
|
||||
if page_size == 1 or self.topk == 1:
|
||||
batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
|
||||
else:
|
||||
batch.out_cache_loc = torch.empty(
|
||||
len(unfinished_index) + sum(draft_input_accept_length_cpu),
|
||||
dtype=torch.int64,
|
||||
device=predict.device,
|
||||
)
|
||||
accept_length_filter = create_accept_length_filter(
|
||||
accept_length,
|
||||
unfinished_index_device,
|
||||
batch.seq_lens,
|
||||
)
|
||||
filter_finished_cache_loc_kernel[(bs,)](
|
||||
batch.out_cache_loc,
|
||||
tgt_cache_loc,
|
||||
accept_length,
|
||||
accept_length_filter,
|
||||
next_power_of_2(bs),
|
||||
next_power_of_2(self.draft_token_num),
|
||||
)
|
||||
|
||||
draft_input.hidden_states = batch.spec_info.hidden_states[
|
||||
unfinished_accept_index
|
||||
]
|
||||
draft_input.verified_id = predict[unfinished_accept_index]
|
||||
draft_input.accept_length_cpu = draft_input_accept_length_cpu
|
||||
draft_input.accept_length = accept_length[unfinished_index_device]
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
||||
unfinished_index_device
|
||||
]
|
||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
||||
unfinished_index_device
|
||||
]
|
||||
if has_finished:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
||||
unfinished_index_device
|
||||
]
|
||||
draft_input.req_pool_indices_for_draft_extend = (
|
||||
batch.req_pool_indices[unfinished_index_device]
|
||||
)
|
||||
else:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||
draft_input.req_pool_indices_for_draft_extend = (
|
||||
batch.req_pool_indices
|
||||
)
|
||||
batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
|
||||
|
||||
return EagleVerifyOutput(
|
||||
draft_input=draft_input,
|
||||
@@ -669,75 +589,36 @@ def assign_draft_cache_locs(
|
||||
req_pool_indices,
|
||||
req_to_token,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
num_new_pages_per_topk,
|
||||
out_cache_loc,
|
||||
pool_len: tl.constexpr,
|
||||
topk: tl.constexpr,
|
||||
speculative_num_steps: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
iter_upper: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 128
|
||||
BLOCK_SIZE: tl.constexpr = 32
|
||||
pid = tl.program_id(axis=0)
|
||||
kv_start = tl.load(seq_lens + pid)
|
||||
|
||||
if page_size == 1 or topk == 1:
|
||||
copy_len = topk * speculative_num_steps
|
||||
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
|
||||
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
||||
else:
|
||||
bs_offset = tl.arange(0, bs_upper)
|
||||
copy_len = tl.load(extend_lens + pid)
|
||||
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
|
||||
out_cache_ptr = out_cache_loc + cum_copy_len
|
||||
prefix_len = tl.load(seq_lens + pid)
|
||||
last_page_len = prefix_len % page_size
|
||||
num_new_page = (
|
||||
last_page_len + speculative_num_steps + page_size - 1
|
||||
) // page_size
|
||||
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
|
||||
|
||||
# Part 1: Copy from out_cache_loc to req_to_token
|
||||
kv_start = tl.load(seq_lens + pid)
|
||||
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
|
||||
|
||||
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = copy_offset < copy_len
|
||||
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
|
||||
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
|
||||
|
||||
if page_size == 1 or topk == 1:
|
||||
return
|
||||
|
||||
# Part 2: Copy the indices for the last partial page
|
||||
prefix_len = tl.load(seq_lens + pid)
|
||||
last_page_len = prefix_len % page_size
|
||||
offsets = tl.arange(0, page_size)
|
||||
mask = offsets < last_page_len
|
||||
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
|
||||
prefix_base = token_pool + prefix_len - last_page_len
|
||||
|
||||
for topk_id in range(topk):
|
||||
value = tl.load(prefix_base + offsets, mask=mask)
|
||||
tl.store(
|
||||
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Part 3: Remove the padding in out_cache_loc
|
||||
iter_offest = tl.arange(0, iter_upper)
|
||||
for topk_id in range(topk):
|
||||
indices = tl.load(
|
||||
prefix_base
|
||||
+ topk_id * num_new_pages_per_topk_ * page_size
|
||||
+ last_page_len
|
||||
+ iter_offest,
|
||||
mask=iter_offest < speculative_num_steps,
|
||||
)
|
||||
tl.store(
|
||||
out_cache_loc
|
||||
+ pid * topk * speculative_num_steps
|
||||
+ topk_id * speculative_num_steps
|
||||
+ iter_offest,
|
||||
indices,
|
||||
mask=iter_offest < speculative_num_steps,
|
||||
)
|
||||
save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
|
||||
load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = save_offset < kv_end
|
||||
data = tl.load(out_cache_ptr + load_offset, mask=mask)
|
||||
tl.store(token_pool + save_offset, data, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -748,23 +629,20 @@ def generate_draft_decode_kv_indices(
|
||||
kv_indices,
|
||||
kv_indptr,
|
||||
positions,
|
||||
num_seqs: tl.constexpr,
|
||||
topk: tl.constexpr,
|
||||
pool_len: tl.constexpr,
|
||||
kv_indices_stride: tl.constexpr,
|
||||
kv_indptr_stride: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
iter_upper: tl.constexpr,
|
||||
num_tokens_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 128
|
||||
iters = tl.program_id(axis=0)
|
||||
bid = tl.program_id(axis=1)
|
||||
topk_id = tl.program_id(axis=2)
|
||||
|
||||
num_steps = tl.num_programs(axis=0)
|
||||
num_seqs = tl.num_programs(axis=1)
|
||||
topk = tl.num_programs(axis=2)
|
||||
|
||||
kv_indices += kv_indices_stride * iters
|
||||
kv_indptr += kv_indptr_stride * iters
|
||||
iters += 1
|
||||
@@ -774,7 +652,6 @@ def generate_draft_decode_kv_indices(
|
||||
seq_len = tl.load(paged_kernel_lens + bid)
|
||||
cum_seq_len = tl.sum(seq_lens)
|
||||
|
||||
# Update kv_indices
|
||||
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
|
||||
kv_ptr = kv_indices + kv_offset
|
||||
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
|
||||
@@ -788,26 +665,10 @@ def generate_draft_decode_kv_indices(
|
||||
kv_offset += BLOCK_SIZE
|
||||
|
||||
extend_offset = tl.arange(0, iter_upper)
|
||||
if page_size == 1 or topk == 1:
|
||||
extend_data = tl.load(
|
||||
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
|
||||
mask=extend_offset < iters,
|
||||
)
|
||||
else:
|
||||
prefix_len = seq_len
|
||||
last_page_len = prefix_len % page_size
|
||||
num_new_pages_per_topk = (
|
||||
last_page_len + num_steps + page_size - 1
|
||||
) // page_size
|
||||
prefix_base = seq_len // page_size * page_size
|
||||
start = (
|
||||
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
|
||||
)
|
||||
extend_data = tl.load(
|
||||
token_pool_ptr + start + extend_offset,
|
||||
mask=extend_offset < iters,
|
||||
)
|
||||
|
||||
extend_data = tl.load(
|
||||
token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
|
||||
mask=extend_offset < iters,
|
||||
)
|
||||
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
||||
|
||||
# Update kv_indptr
|
||||
@@ -846,116 +707,6 @@ def align_evict_mask_to_page_size(
|
||||
tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def get_target_cache_loc(
|
||||
tgt_cache_loc,
|
||||
to_free_slots,
|
||||
accept_length,
|
||||
to_free_num_slots,
|
||||
out_cache_loc,
|
||||
num_verify_tokens: tl.constexpr,
|
||||
num_verify_tokens_upper: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
):
|
||||
bid = tl.program_id(axis=0)
|
||||
offset = tl.arange(0, num_verify_tokens_upper)
|
||||
bs_offset = tl.arange(0, bs_upper)
|
||||
|
||||
# write the first part to tgt_cache_loc
|
||||
accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
|
||||
tgt_cache_loc_start = tl.sum(accept_len_all) + bid
|
||||
copy_len = tl.load(accept_length + bid) + 1
|
||||
out_cache_loc_row = tl.load(
|
||||
out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
|
||||
)
|
||||
tl.store(
|
||||
tgt_cache_loc + tgt_cache_loc_start + offset,
|
||||
out_cache_loc_row,
|
||||
mask=offset < copy_len,
|
||||
)
|
||||
|
||||
# write the second part to to_free_num_pages
|
||||
to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
|
||||
to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
|
||||
out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
|
||||
to_free_slots_start = tl.sum(to_free_num_slots_all)
|
||||
|
||||
copy_len = to_free_num_slots_cur
|
||||
out_cache_loc_row = tl.load(
|
||||
out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
|
||||
mask=offset < copy_len,
|
||||
)
|
||||
tl.store(
|
||||
to_free_slots + to_free_slots_start + offset,
|
||||
out_cache_loc_row,
|
||||
mask=offset < copy_len,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_src_tgt_cache_loc(
|
||||
seq_lens: torch.Tensor,
|
||||
out_cache_loc: torch.Tensor,
|
||||
accept_index: torch.Tensor,
|
||||
accept_length: torch.Tensor,
|
||||
draft_token_num: int,
|
||||
page_size: int,
|
||||
):
|
||||
src_cache_loc = out_cache_loc[accept_index]
|
||||
tgt_cache_loc = torch.empty_like(src_cache_loc)
|
||||
extended_len = seq_lens + draft_token_num
|
||||
keep_len = torch.minimum(
|
||||
(seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
|
||||
extended_len,
|
||||
)
|
||||
to_free_num_slots = extended_len - keep_len
|
||||
return src_cache_loc, tgt_cache_loc, to_free_num_slots
|
||||
|
||||
|
||||
@triton.jit
|
||||
def filter_finished_cache_loc_kernel(
|
||||
out_cache_loc,
|
||||
tgt_cache_loc,
|
||||
accept_length,
|
||||
accept_length_filter,
|
||||
bs_upper: tl.constexpr,
|
||||
num_verify_tokens_upper: tl.constexpr,
|
||||
):
|
||||
bid = tl.program_id(0)
|
||||
bs_offset = tl.arange(0, bs_upper)
|
||||
|
||||
accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
|
||||
old_start = tl.sum(accept_length_all) + bid
|
||||
|
||||
accept_length_filter_all = tl.load(
|
||||
accept_length_filter + bs_offset, mask=bs_offset < bid
|
||||
)
|
||||
new_start = tl.sum(accept_length_filter_all)
|
||||
|
||||
copy_len = tl.load(accept_length_filter + bid)
|
||||
copy_offset = tl.arange(0, num_verify_tokens_upper)
|
||||
value = tl.load(
|
||||
tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
|
||||
)
|
||||
tl.store(
|
||||
out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def create_accept_length_filter(
|
||||
accept_length: torch.Tensor,
|
||||
unfinished_index_device: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
):
|
||||
accept_length_filter = torch.zeros_like(accept_length)
|
||||
accept_length_filter[unfinished_index_device] = (
|
||||
accept_length[unfinished_index_device] + 1
|
||||
)
|
||||
seq_lens.add_(accept_length + 1)
|
||||
return accept_length_filter
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def select_top_k_tokens(
|
||||
i: int,
|
||||
@@ -1005,16 +756,6 @@ def select_top_k_tokens(
|
||||
return input_ids, hidden_states, scores, tree_info
|
||||
|
||||
|
||||
def fast_topk_torch(values, topk, dim):
|
||||
if topk == 1:
|
||||
# Use max along the specified dimension to get both value and index
|
||||
max_value, max_index = torch.max(values, dim=dim)
|
||||
return max_value.unsqueeze(1), max_index.unsqueeze(1)
|
||||
else:
|
||||
# Use topk for efficiency with larger k values
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
|
||||
|
||||
def _generate_simulated_accept_index(
|
||||
accept_index,
|
||||
predict,
|
||||
@@ -1024,35 +765,15 @@ def _generate_simulated_accept_index(
|
||||
spec_steps,
|
||||
):
|
||||
simulate_acc_len_float = float(simulate_acc_len)
|
||||
if SIMULATE_ACC_METHOD == "multinomial":
|
||||
simulated_values = torch.normal(
|
||||
mean=simulate_acc_len_float,
|
||||
std=1.0,
|
||||
size=(1,),
|
||||
device="cpu",
|
||||
)
|
||||
# clamp simulated values to be between 1 and self.spec_steps
|
||||
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
|
||||
simulate_acc_len = int(simulated_values.round().item())
|
||||
elif SIMULATE_ACC_METHOD == "match-expected":
|
||||
# multinomial sampling does not match the expected length
|
||||
# we keep it for the sake of compatibility of existing tests
|
||||
# but it's better to use "match-expected" for the cases that need to
|
||||
# match the expected length, One caveat is that this will only sample
|
||||
# either round down or round up of the expected length
|
||||
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
|
||||
lower = int(simulate_acc_len_float // 1)
|
||||
upper = lower + 1 if lower < spec_steps + 1 else lower
|
||||
if lower == upper:
|
||||
simulate_acc_len = lower
|
||||
else:
|
||||
weight_upper = simulate_acc_len_float - lower
|
||||
weight_lower = 1.0 - weight_upper
|
||||
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
|
||||
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||
simulate_acc_len = lower if sampled_index == 0 else upper
|
||||
else:
|
||||
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
|
||||
simulated_values = torch.normal(
|
||||
mean=simulate_acc_len_float,
|
||||
std=1.0,
|
||||
size=(1,),
|
||||
device="cpu",
|
||||
)
|
||||
# clamp simulated values to be between 1 and self.spec_steps
|
||||
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
|
||||
simulate_acc_len = int(simulated_values.round().item())
|
||||
|
||||
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
|
||||
sim_accept_index = torch.full(
|
||||
@@ -1143,9 +864,9 @@ def generate_token_bitmask(
|
||||
"""
|
||||
Generate the logit mask for structured output.
|
||||
Draft model's token can be either valid or invalid with respect to the grammar.
|
||||
We need to perform DFS to
|
||||
1. figure out which tokens are accepted by the grammar.
|
||||
2. if so, what is the corresponding logit mask.
|
||||
We need to perform DFS to figure out:
|
||||
1. which tokens are accepted by the grammar
|
||||
2. what is the corresponding logit mask.
|
||||
"""
|
||||
|
||||
num_draft_tokens = draft_tokens_cpu.shape[-1]
|
||||
@@ -1162,7 +883,6 @@ def generate_token_bitmask(
|
||||
device="cpu",
|
||||
)
|
||||
grammar = req.grammar
|
||||
s = time.perf_counter()
|
||||
traverse_tree(
|
||||
retrieve_next_token_cpu[i],
|
||||
retrieve_next_sibling_cpu[i],
|
||||
@@ -1172,12 +892,6 @@ def generate_token_bitmask(
|
||||
i * num_draft_tokens : (i + 1) * num_draft_tokens
|
||||
],
|
||||
)
|
||||
tree_traverse_time = time.perf_counter() - s
|
||||
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Bit mask generation took {tree_traverse_time} seconds with "
|
||||
f"grammar: {req.grammar}"
|
||||
)
|
||||
|
||||
verify_input.grammar = grammar
|
||||
return allocate_token_bitmask
|
||||
|
||||
@@ -35,17 +35,11 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
EagleVerifyInput,
|
||||
EagleVerifyOutput,
|
||||
assign_draft_cache_locs,
|
||||
fast_topk,
|
||||
generate_token_bitmask,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
is_cuda,
|
||||
next_power_of_2,
|
||||
)
|
||||
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import segment_packbits
|
||||
@@ -158,12 +152,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
|
||||
# Some dummy tensors
|
||||
self.num_new_pages_per_topk = torch.empty(
|
||||
(), dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
@@ -266,7 +254,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
# Capture extend
|
||||
@@ -281,7 +269,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -302,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
A tuple of the final logit output of the target model, next tokens accepted,
|
||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||
"""
|
||||
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
@@ -377,21 +366,14 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Allocate cache locations
|
||||
# Layout of the out_cache_loc
|
||||
# [ topk 0 ] [ topk 1 ]
|
||||
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
|
||||
if self.page_size == 1:
|
||||
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
||||
num_seqs * self.speculative_num_steps * self.topk, backup_state=True
|
||||
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
|
||||
)
|
||||
else:
|
||||
if self.topk == 1:
|
||||
prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
prefix_lens = batch.seq_lens
|
||||
seq_lens = prefix_lens + self.speculative_num_steps
|
||||
extend_num_tokens = num_seqs * self.speculative_num_steps
|
||||
else:
|
||||
# In this case, the last partial page needs to be duplicated.
|
||||
@@ -404,33 +386,29 @@ class EAGLEWorker(TpModelWorker):
|
||||
# "x" means speculative draft tokens
|
||||
# "." means padded tokens
|
||||
|
||||
# TODO(lmzheng): The current implementation is still a fake support
|
||||
# for page size > 1. In the `assign_draft_cache_locs` below,
|
||||
# we directly move the indices instead of the real kv cache.
|
||||
# This only works when the kernel backend runs with page size = 1.
|
||||
# If the kernel backend runs with page size > 1, we need to
|
||||
# duplicate the real KV cache. The overhead of duplicating KV
|
||||
# cache seems okay because the draft KV cache only has one layer.
|
||||
# see a related copy operation in MHATokenToKVPool::move_kv_cache.
|
||||
|
||||
(
|
||||
prefix_lens,
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.num_new_pages_per_topk,
|
||||
self.extend_lens,
|
||||
) = get_last_loc_large_page_size_large_top_k(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
self.speculative_num_steps,
|
||||
self.topk,
|
||||
self.page_size,
|
||||
# TODO: fuse these ops
|
||||
prefix_lens = batch.seq_lens
|
||||
last_page_lens = prefix_lens % self.page_size
|
||||
num_new_pages = (
|
||||
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
||||
) // self.page_size
|
||||
seq_lens = (
|
||||
prefix_lens // self.page_size * self.page_size
|
||||
+ num_new_pages * (self.page_size * self.topk)
|
||||
)
|
||||
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
|
||||
raise NotImplementedError(
|
||||
"page_size > 1 and top_k > 1 are not supported."
|
||||
)
|
||||
# TODO: Support page_size > 1 and top_k > 1
|
||||
# 1. Duplicate the KV cache in the last partial page for all top-k segments
|
||||
# 2. Modify generate_draft_decode_kv_indices accordingly
|
||||
|
||||
# TODO(lmzheng): remove this device sync
|
||||
extend_num_tokens = torch.sum(self.extend_lens).item()
|
||||
|
||||
last_loc = get_last_loc(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||
batch.alloc_paged_token_slots_extend(
|
||||
prefix_lens,
|
||||
@@ -445,31 +423,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.seq_lens,
|
||||
self.extend_lens,
|
||||
self.num_new_pages_per_topk,
|
||||
out_cache_loc,
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.page_size,
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
)
|
||||
|
||||
if self.page_size > 1 and self.topk > 1:
|
||||
# Remove padded slots
|
||||
out_cache_loc = out_cache_loc[
|
||||
: num_seqs * self.topk * self.speculative_num_steps
|
||||
]
|
||||
|
||||
batch.out_cache_loc = out_cache_loc
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
batch.return_hidden_states = False
|
||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
|
||||
# Get forward batch
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -538,13 +504,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
if self.hot_token_id is not None:
|
||||
topk_index = self.hot_token_id[topk_index]
|
||||
|
||||
out_cache_loc = out_cache_loc.reshape(
|
||||
forward_batch.batch_size, self.topk, self.speculative_num_steps
|
||||
)
|
||||
out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
|
||||
self.speculative_num_steps, -1
|
||||
)
|
||||
|
||||
# Return values
|
||||
score_list: List[torch.Tensor] = []
|
||||
token_list: List[torch.Tensor] = []
|
||||
@@ -566,7 +525,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Set inputs
|
||||
forward_batch.input_ids = input_ids
|
||||
forward_batch.out_cache_loc = out_cache_loc[i:]
|
||||
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
|
||||
forward_batch.out_cache_loc = out_cache_loc[
|
||||
:, self.topk * i : self.topk * (i + 1)
|
||||
].flatten()
|
||||
forward_batch.positions.add_(1)
|
||||
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||
spec_info.hidden_states = hidden_states
|
||||
@@ -624,7 +586,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
if vocab_mask is not None:
|
||||
assert spec_info.grammar is not None
|
||||
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
|
||||
# NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
|
||||
# otherwise, this vocab mask will be the one from the previous extend stage
|
||||
# and will be applied to produce wrong results
|
||||
batch.sampling_info.vocab_mask = None
|
||||
|
||||
@@ -645,13 +607,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
]
|
||||
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
|
||||
|
||||
if batch.return_logprob:
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
# Prepare the batch for the next draft forwards.
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.spec_info = res.draft_input
|
||||
|
||||
if batch.return_logprob:
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||
|
||||
def add_logprob_values(
|
||||
@@ -664,16 +626,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output = res.logits_output
|
||||
top_logprobs_nums = batch.top_logprobs_nums
|
||||
token_ids_logprobs = batch.token_ids_logprobs
|
||||
accepted_indices = res.accepted_indices
|
||||
assert len(accepted_indices) == len(logits_output.next_token_logits)
|
||||
temperatures = batch.sampling_info.temperatures
|
||||
num_draft_tokens = batch.spec_info.draft_token_num
|
||||
# acceptance indices are the indices in a "flattened" batch.
|
||||
# dividing it to num_draft_tokens will yield the actual batch index.
|
||||
temperatures = temperatures[accepted_indices // num_draft_tokens]
|
||||
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits / temperatures, dim=-1
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
batch_next_token_ids = res.verified_id
|
||||
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||
@@ -708,7 +662,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
pt = 0
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
verified_ids = batch_next_token_ids.tolist()
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||
for _ in range(num_tokens):
|
||||
if req.return_logprob:
|
||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||
@@ -736,6 +690,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
hidden_states: Hidden states from the target model forward
|
||||
next_token_ids: Next token ids generated from the target forward.
|
||||
"""
|
||||
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
|
||||
batch.spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
verified_id=next_token_ids,
|
||||
@@ -746,6 +701,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -768,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -832,47 +790,3 @@ def load_token_map(token_map_path: str) -> List[int]:
|
||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||
return torch.tensor(hot_token_id, dtype=torch.int32)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_last_loc_large_page_size_top_k_1(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
prefix_lens = seq_lens
|
||||
seq_lens = prefix_lens + speculative_num_steps
|
||||
last_loc = get_last_loc(
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
return prefix_lens, seq_lens, last_loc
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_last_loc_large_page_size_large_top_k(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
speculative_num_steps: int,
|
||||
topk: int,
|
||||
page_size: int,
|
||||
):
|
||||
prefix_lens = seq_lens
|
||||
last_page_lens = prefix_lens % page_size
|
||||
num_new_pages_per_topk = (
|
||||
last_page_lens + speculative_num_steps + page_size - 1
|
||||
) // page_size
|
||||
seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
|
||||
page_size * topk
|
||||
)
|
||||
extend_lens = seq_lens - prefix_lens
|
||||
last_loc = get_last_loc(
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
|
||||
return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
|
||||
|
||||
Reference in New Issue
Block a user