[EAGLE] Refactor code for page size > 1 & more simplifications (#7213)
This commit is contained in:
@@ -1049,14 +1049,13 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
kv_indices_buffer,
|
kv_indices_buffer,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
forward_batch.positions,
|
forward_batch.positions,
|
||||||
num_seqs,
|
|
||||||
self.topk,
|
|
||||||
self.pool_len,
|
self.pool_len,
|
||||||
kv_indices_buffer.shape[1],
|
kv_indices_buffer.shape[1],
|
||||||
self.kv_indptr.shape[1],
|
self.kv_indptr.shape[1],
|
||||||
next_power_of_2(num_seqs),
|
next_power_of_2(num_seqs),
|
||||||
next_power_of_2(self.speculative_num_steps),
|
next_power_of_2(self.speculative_num_steps),
|
||||||
next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert forward_batch.spec_info is not None
|
assert forward_batch.spec_info is not None
|
||||||
|
|||||||
@@ -789,6 +789,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|||||||
|
|
||||||
# Cached variables for generate_draft_decode_kv_indices
|
# Cached variables for generate_draft_decode_kv_indices
|
||||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
|
self.page_size = model_runner.server_args.page_size
|
||||||
|
|
||||||
def common_template(
|
def common_template(
|
||||||
self,
|
self,
|
||||||
@@ -809,14 +810,13 @@ class FlashInferMLAMultiStepDraftBackend:
|
|||||||
kv_indices_buffer,
|
kv_indices_buffer,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
forward_batch.positions,
|
forward_batch.positions,
|
||||||
num_seqs,
|
|
||||||
self.topk,
|
|
||||||
self.pool_len,
|
self.pool_len,
|
||||||
kv_indices_buffer.shape[1],
|
kv_indices_buffer.shape[1],
|
||||||
self.kv_indptr.shape[1],
|
self.kv_indptr.shape[1],
|
||||||
next_power_of_2(num_seqs),
|
next_power_of_2(num_seqs),
|
||||||
next_power_of_2(self.speculative_num_steps),
|
next_power_of_2(self.speculative_num_steps),
|
||||||
next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert forward_batch.spec_info is not None
|
assert forward_batch.spec_info is not None
|
||||||
|
|||||||
@@ -2,9 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Support attention backend for FlashMLA.
|
Support attention backend for FlashMLA.
|
||||||
|
|
||||||
#TODO
|
|
||||||
Enable speculative sampling in FlashMLA
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|||||||
@@ -784,14 +784,13 @@ class TritonMultiStepDraftBackend:
|
|||||||
kv_indices_buffer,
|
kv_indices_buffer,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
forward_batch.positions,
|
forward_batch.positions,
|
||||||
num_seqs,
|
|
||||||
self.topk,
|
|
||||||
self.pool_len,
|
self.pool_len,
|
||||||
kv_indices_buffer.shape[1],
|
kv_indices_buffer.shape[1],
|
||||||
self.kv_indptr.shape[1],
|
self.kv_indptr.shape[1],
|
||||||
next_power_of_2(num_seqs),
|
next_power_of_2(num_seqs),
|
||||||
next_power_of_2(self.speculative_num_steps),
|
next_power_of_2(self.speculative_num_steps),
|
||||||
next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
|
|||||||
@@ -294,6 +294,19 @@ class MHATokenToKVPool(KVCache):
|
|||||||
for _ in range(self.layer_num)
|
for _ in range(self.layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.data_ptrs = torch.tensor(
|
||||||
|
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
||||||
|
dtype=torch.uint64,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.data_strides = torch.tensor(
|
||||||
|
[
|
||||||
|
np.prod(x.shape[1:]) * x.dtype.itemsize
|
||||||
|
for x in self.k_buffer + self.v_buffer
|
||||||
|
],
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
def _clear_buffers(self):
|
def _clear_buffers(self):
|
||||||
del self.k_buffer
|
del self.k_buffer
|
||||||
del self.v_buffer
|
del self.v_buffer
|
||||||
@@ -451,6 +464,16 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
|
|
||||||
|
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
    """Copy cache rows from ``src_loc`` to ``tgt_loc`` in every registered buffer.

    Launches ``copy_all_layer_kv_cache`` with one program per entry of
    ``self.data_ptrs``; each program processes all ``len(tgt_loc)``
    (source, target) pairs for its buffer.

    Args:
        tgt_loc: Destination row indices.
        src_loc: Source row indices.
            NOTE(review): presumably the same length as ``tgt_loc`` -- only
            ``len(tgt_loc)`` is passed to the kernel; verify against callers.
    """
    num_locs = len(tgt_loc)
    launch_grid = (len(self.data_ptrs),)
    copy_all_layer_kv_cache[launch_grid](
        self.data_ptrs,
        self.data_strides,
        tgt_loc,
        src_loc,
        num_locs,
        # Triton ranges need a compile-time upper bound; round up.
        next_power_of_2(num_locs),
    )
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def set_mla_kv_buffer_kernel(
|
def set_mla_kv_buffer_kernel(
|
||||||
@@ -741,3 +764,41 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
|
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
def transfer_per_layer(self, indices, flat_data, layer_id):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def copy_all_layer_kv_cache(
    data_ptrs,
    strides,
    tgt_loc_ptr,
    src_loc_ptr,
    num_locs,
    num_locs_upper: tl.constexpr,
):
    """Copy rows ``src_loc -> tgt_loc`` inside one buffer per program.

    Program ``pid`` reads its buffer base address from ``data_ptrs[pid]`` and
    its per-row size in bytes from ``strides[pid]``, then copies ``num_locs``
    rows byte-wise in chunks of ``BLOCK_SIZE`` bytes.

    Args:
        data_ptrs: Table of raw buffer addresses, one per program.
        strides: Table of row sizes in bytes, one per program.
        tgt_loc_ptr: Destination row indices.
        src_loc_ptr: Source row indices.
        num_locs: Number of valid (source, target) pairs.
        num_locs_upper: Compile-time bound on ``num_locs`` used for ``tl.arange``.
    """
    BLOCK_SIZE: tl.constexpr = 128

    buf_id = tl.program_id(0)
    row_bytes = tl.load(strides + buf_id)

    # Reinterpret the stored integer address as a byte pointer so that the
    # copy is independent of the buffer's element type.
    raw_addr = tl.load(data_ptrs + buf_id)
    base_ptr = tl.cast(raw_addr, tl.pointer_type(tl.uint8))

    loc_idx = tl.arange(0, num_locs_upper)
    loc_valid = loc_idx < num_locs
    dst_rows = tl.load(tgt_loc_ptr + loc_idx, mask=loc_valid)
    src_rows = tl.load(src_loc_ptr + loc_idx, mask=loc_valid)

    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
    # because this copy is an inplace operation.

    num_blocks = tl.cdiv(row_bytes, BLOCK_SIZE)
    for blk in range(num_blocks):
        byte_off = tl.arange(0, BLOCK_SIZE) + blk * BLOCK_SIZE
        # Valid element = valid row AND byte offset inside the row.
        in_bounds = loc_valid[:, None] & (byte_off < row_bytes)[None, :]
        chunk = tl.load(
            base_ptr + src_rows[:, None] * row_bytes + byte_off[None, :],
            mask=in_bounds,
        )
        tl.store(
            base_ptr + dst_rows[:, None] * row_bytes + byte_off[None, :],
            chunk,
            mask=in_bounds,
        )
|
||||||
|
|||||||
@@ -67,8 +67,6 @@ class EagleDraftInput:
|
|||||||
kv_indptr: torch.Tensor = None
|
kv_indptr: torch.Tensor = None
|
||||||
kv_indices: torch.Tensor = None
|
kv_indices: torch.Tensor = None
|
||||||
|
|
||||||
all_padding_lens: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||||
# Prefill only generate 1 token.
|
# Prefill only generate 1 token.
|
||||||
assert len(self.verified_id) == len(batch.seq_lens)
|
assert len(self.verified_id) == len(batch.seq_lens)
|
||||||
@@ -93,6 +91,7 @@ class EagleDraftInput:
|
|||||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||||
batch.return_logprob = False
|
batch.return_logprob = False
|
||||||
|
batch.return_hidden_states = False
|
||||||
|
|
||||||
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
self.accept_length.add_(1)
|
self.accept_length.add_(1)
|
||||||
@@ -116,10 +115,8 @@ class EagleDraftInput:
|
|||||||
req_to_token: torch.Tensor,
|
req_to_token: torch.Tensor,
|
||||||
):
|
):
|
||||||
bs = self.accept_length.numel()
|
bs = self.accept_length.numel()
|
||||||
|
|
||||||
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||||
|
|
||||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
|
||||||
@@ -139,7 +136,6 @@ class EagleDraftInput:
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
req_to_token.size(1),
|
req_to_token.size(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||||
|
|
||||||
def filter_batch(self, new_indices: torch.Tensor):
|
def filter_batch(self, new_indices: torch.Tensor):
|
||||||
@@ -270,7 +266,7 @@ class EagleVerifyInput:
|
|||||||
logits_output: torch.Tensor,
|
logits_output: torch.Tensor,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
vocab_mask: Optional[torch.Tensor] = None,
|
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Verify and find accepted tokens based on logits output and batch
|
Verify and find accepted tokens based on logits output and batch
|
||||||
@@ -294,6 +290,14 @@ class EagleVerifyInput:
|
|||||||
)
|
)
|
||||||
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# Apply the custom logit processors if registered in the sampling info.
|
||||||
|
if sampling_info.has_custom_logit_processor:
|
||||||
|
apply_custom_logit_processor(
|
||||||
|
logits_output.next_token_logits,
|
||||||
|
sampling_info,
|
||||||
|
num_tokens_in_batch=self.draft_token_num,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply penalty
|
# Apply penalty
|
||||||
if sampling_info.penalizer_orchestrator.is_required:
|
if sampling_info.penalizer_orchestrator.is_required:
|
||||||
# This is a relaxed version of penalties for speculative decoding.
|
# This is a relaxed version of penalties for speculative decoding.
|
||||||
@@ -355,7 +359,13 @@ class EagleVerifyInput:
|
|||||||
draft_probs = torch.zeros(
|
draft_probs = torch.zeros(
|
||||||
target_probs.shape, dtype=torch.float32, device="cuda"
|
target_probs.shape, dtype=torch.float32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# coins for rejection sampling
|
||||||
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
|
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
|
||||||
|
# coins for final sampling
|
||||||
|
coins_for_final_sampling = torch.rand(
|
||||||
|
(bs,), dtype=torch.float32, device="cuda"
|
||||||
|
)
|
||||||
tree_speculative_sampling_target_only(
|
tree_speculative_sampling_target_only(
|
||||||
predicts=predict, # mutable
|
predicts=predict, # mutable
|
||||||
accept_index=accept_index, # mutable
|
accept_index=accept_index, # mutable
|
||||||
@@ -365,6 +375,7 @@ class EagleVerifyInput:
|
|||||||
retrive_next_token=self.retrive_next_token.to(torch.int32),
|
retrive_next_token=self.retrive_next_token.to(torch.int32),
|
||||||
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
|
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
|
||||||
uniform_samples=coins,
|
uniform_samples=coins,
|
||||||
|
# uniform_samples_for_final_sampling=coins_for_final_sampling,
|
||||||
target_probs=target_probs,
|
target_probs=target_probs,
|
||||||
draft_probs=draft_probs,
|
draft_probs=draft_probs,
|
||||||
threshold_single=global_server_args_dict[
|
threshold_single=global_server_args_dict[
|
||||||
@@ -387,8 +398,8 @@ class EagleVerifyInput:
|
|||||||
spec_steps=self.spec_steps,
|
spec_steps=self.spec_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_accept_index = []
|
|
||||||
unfinished_index = []
|
unfinished_index = []
|
||||||
|
unfinished_accept_index = []
|
||||||
accept_index_cpu = accept_index.tolist()
|
accept_index_cpu = accept_index.tolist()
|
||||||
predict_cpu = predict.tolist()
|
predict_cpu = predict.tolist()
|
||||||
has_finished = False
|
has_finished = False
|
||||||
@@ -396,12 +407,10 @@ class EagleVerifyInput:
|
|||||||
# Iterate every accepted token and check if req has finished after append the token
|
# Iterate every accepted token and check if req has finished after append the token
|
||||||
# should be checked BEFORE free kv cache slots
|
# should be checked BEFORE free kv cache slots
|
||||||
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
||||||
new_accept_index_ = []
|
|
||||||
for j, idx in enumerate(accept_index_row):
|
for j, idx in enumerate(accept_index_row):
|
||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
id = predict_cpu[idx]
|
id = predict_cpu[idx]
|
||||||
# if not found_finished:
|
|
||||||
req.output_ids.append(id)
|
req.output_ids.append(id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
if req.finished():
|
if req.finished():
|
||||||
@@ -410,8 +419,6 @@ class EagleVerifyInput:
|
|||||||
accept_index[i, j + 1 :] = -1
|
accept_index[i, j + 1 :] = -1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
new_accept_index_.append(idx)
|
|
||||||
# update grammar state
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
try:
|
try:
|
||||||
req.grammar.accept_token(id)
|
req.grammar.accept_token(id)
|
||||||
@@ -421,50 +428,104 @@ class EagleVerifyInput:
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
if not req.finished():
|
if not req.finished():
|
||||||
new_accept_index.extend(new_accept_index_)
|
|
||||||
unfinished_index.append(i)
|
unfinished_index.append(i)
|
||||||
|
if idx == -1:
|
||||||
|
unfinished_accept_index.append(accept_index[i, :j])
|
||||||
|
else:
|
||||||
|
unfinished_accept_index.append(accept_index[i])
|
||||||
req.spec_verify_ct += 1
|
req.spec_verify_ct += 1
|
||||||
|
|
||||||
if has_finished:
|
if has_finished:
|
||||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
|
||||||
# Free the KV cache for unaccepted tokens
|
# Free the KV cache for unaccepted tokens
|
||||||
|
# TODO: fuse them
|
||||||
accept_index = accept_index[accept_index != -1]
|
accept_index = accept_index[accept_index != -1]
|
||||||
verified_id = predict[accept_index]
|
verified_id = predict[accept_index]
|
||||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||||
evict_mask[accept_index] = False
|
evict_mask[accept_index] = False
|
||||||
|
|
||||||
if page_size != 1:
|
if page_size == 1:
|
||||||
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
# TODO: boolean array index leads to a device sync. Remove it.
|
||||||
batch.seq_lens,
|
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
||||||
evict_mask,
|
else:
|
||||||
page_size,
|
if self.topk == 1:
|
||||||
self.draft_token_num,
|
# Only evict full empty page. Do not evict partial empty page
|
||||||
next_power_of_2(self.draft_token_num),
|
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
||||||
)
|
batch.seq_lens,
|
||||||
|
evict_mask,
|
||||||
|
page_size,
|
||||||
|
self.draft_token_num,
|
||||||
|
next_power_of_2(self.draft_token_num),
|
||||||
|
)
|
||||||
|
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
||||||
|
else:
|
||||||
|
# Shift the accepted tokens to the beginning.
|
||||||
|
# Only evict the last part
|
||||||
|
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
|
||||||
|
batch.seq_lens,
|
||||||
|
batch.out_cache_loc,
|
||||||
|
accept_index,
|
||||||
|
accept_length,
|
||||||
|
self.draft_token_num,
|
||||||
|
page_size,
|
||||||
|
)
|
||||||
|
to_free_slots = torch.empty(
|
||||||
|
(to_free_num_slots.sum().item(),),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=to_free_num_slots.device,
|
||||||
|
)
|
||||||
|
|
||||||
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
|
||||||
|
# accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
|
||||||
|
# tgt_cache_loc: [0 1 , 3 4 , 6 ]
|
||||||
|
# to_free_slots: [ 2, 5, 7 8]
|
||||||
|
# to_free_slots also needs to be page-aligned without the first partial page
|
||||||
|
#
|
||||||
|
# split each row of out_cache_loc into two parts.
|
||||||
|
# 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
|
||||||
|
# 2. the second part goes to to_free_slots.
|
||||||
|
get_target_cache_loc[(bs,)](
|
||||||
|
tgt_cache_loc,
|
||||||
|
to_free_slots,
|
||||||
|
accept_length,
|
||||||
|
to_free_num_slots,
|
||||||
|
batch.out_cache_loc,
|
||||||
|
self.draft_token_num,
|
||||||
|
next_power_of_2(self.draft_token_num),
|
||||||
|
next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Free the kv cache
|
||||||
|
token_to_kv_pool_allocator.free(to_free_slots)
|
||||||
|
|
||||||
|
# Copy the kv cache
|
||||||
|
batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
||||||
|
tgt_cache_loc, src_cache_loc
|
||||||
|
)
|
||||||
|
|
||||||
# Construct EagleVerifyOutput
|
# Construct EagleVerifyOutput
|
||||||
if not has_finished:
|
if not has_finished:
|
||||||
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
if page_size == 1 or self.topk == 1:
|
||||||
assign_req_to_token_pool[(bs,)](
|
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
||||||
batch.req_pool_indices,
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_pool_indices,
|
||||||
batch.seq_lens,
|
batch.req_to_token_pool.req_to_token,
|
||||||
batch.seq_lens + accept_length + 1,
|
batch.seq_lens,
|
||||||
batch.out_cache_loc,
|
batch.seq_lens + accept_length + 1,
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.out_cache_loc,
|
||||||
next_power_of_2(bs),
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
)
|
next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch.out_cache_loc = tgt_cache_loc
|
||||||
batch.seq_lens.add_(accept_length + 1)
|
batch.seq_lens.add_(accept_length + 1)
|
||||||
accept_length_cpu = accept_length.tolist()
|
|
||||||
|
|
||||||
draft_input = EagleDraftInput()
|
draft_input = EagleDraftInput()
|
||||||
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
|
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
|
||||||
draft_input.verified_id = verified_id
|
draft_input.verified_id = verified_id
|
||||||
draft_input.accept_length = accept_length
|
draft_input.accept_length = accept_length
|
||||||
draft_input.accept_length_cpu = accept_length_cpu
|
draft_input.accept_length_cpu = accept_length.tolist()
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
||||||
|
|
||||||
@@ -472,47 +533,66 @@ class EagleVerifyInput:
|
|||||||
draft_input=draft_input,
|
draft_input=draft_input,
|
||||||
logits_output=logits_output,
|
logits_output=logits_output,
|
||||||
verified_id=verified_id,
|
verified_id=verified_id,
|
||||||
accept_length_per_req_cpu=accept_length_cpu,
|
accept_length_per_req_cpu=draft_input.accept_length_cpu,
|
||||||
accepted_indices=accept_index,
|
accepted_indices=accept_index,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assign_req_to_token_pool[(bs,)](
|
if page_size == 1 or self.topk == 1:
|
||||||
batch.req_pool_indices,
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_pool_indices,
|
||||||
batch.seq_lens,
|
batch.req_to_token_pool.req_to_token,
|
||||||
batch.seq_lens + accept_length + 1,
|
batch.seq_lens,
|
||||||
batch.out_cache_loc[accept_index],
|
batch.seq_lens + accept_length + 1,
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.out_cache_loc[accept_index],
|
||||||
next_power_of_2(bs),
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
)
|
next_power_of_2(bs),
|
||||||
batch.seq_lens.add_(accept_length + 1)
|
)
|
||||||
accept_length_cpu = accept_length.tolist()
|
batch.seq_lens.add_(accept_length + 1)
|
||||||
|
|
||||||
|
accept_length_cpu = accept_length.tolist()
|
||||||
draft_input = EagleDraftInput()
|
draft_input = EagleDraftInput()
|
||||||
if len(new_accept_index) > 0:
|
if len(unfinished_accept_index) > 0:
|
||||||
new_accept_index = torch.tensor(new_accept_index, device="cuda")
|
unfinished_accept_index = torch.cat(unfinished_accept_index)
|
||||||
unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
|
unfinished_index_device = torch.tensor(
|
||||||
draft_input.hidden_states = batch.spec_info.hidden_states[
|
unfinished_index, dtype=torch.int64, device=predict.device
|
||||||
new_accept_index
|
)
|
||||||
]
|
draft_input_accept_length_cpu = [
|
||||||
draft_input.verified_id = predict[new_accept_index]
|
|
||||||
draft_input.accept_length_cpu = [
|
|
||||||
accept_length_cpu[i] for i in unfinished_index
|
accept_length_cpu[i] for i in unfinished_index
|
||||||
]
|
]
|
||||||
draft_input.accept_length = accept_length[unfinished_index_device]
|
if page_size == 1 or self.topk == 1:
|
||||||
if has_finished:
|
batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
|
||||||
unfinished_index_device
|
|
||||||
]
|
|
||||||
draft_input.req_pool_indices_for_draft_extend = (
|
|
||||||
batch.req_pool_indices[unfinished_index_device]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
batch.out_cache_loc = torch.empty(
|
||||||
draft_input.req_pool_indices_for_draft_extend = (
|
len(unfinished_index) + sum(draft_input_accept_length_cpu),
|
||||||
batch.req_pool_indices
|
dtype=torch.int64,
|
||||||
|
device=predict.device,
|
||||||
)
|
)
|
||||||
batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
|
accept_length_filter = create_accept_length_filter(
|
||||||
|
accept_length,
|
||||||
|
unfinished_index_device,
|
||||||
|
batch.seq_lens,
|
||||||
|
)
|
||||||
|
filter_finished_cache_loc_kernel[(bs,)](
|
||||||
|
batch.out_cache_loc,
|
||||||
|
tgt_cache_loc,
|
||||||
|
accept_length,
|
||||||
|
accept_length_filter,
|
||||||
|
next_power_of_2(bs),
|
||||||
|
next_power_of_2(self.draft_token_num),
|
||||||
|
)
|
||||||
|
|
||||||
|
draft_input.hidden_states = batch.spec_info.hidden_states[
|
||||||
|
unfinished_accept_index
|
||||||
|
]
|
||||||
|
draft_input.verified_id = predict[unfinished_accept_index]
|
||||||
|
draft_input.accept_length_cpu = draft_input_accept_length_cpu
|
||||||
|
draft_input.accept_length = accept_length[unfinished_index_device]
|
||||||
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
||||||
|
unfinished_index_device
|
||||||
|
]
|
||||||
|
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
||||||
|
unfinished_index_device
|
||||||
|
]
|
||||||
|
|
||||||
return EagleVerifyOutput(
|
return EagleVerifyOutput(
|
||||||
draft_input=draft_input,
|
draft_input=draft_input,
|
||||||
@@ -589,36 +669,75 @@ def assign_draft_cache_locs(
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
num_new_pages_per_topk,
|
||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
pool_len: tl.constexpr,
|
pool_len: tl.constexpr,
|
||||||
topk: tl.constexpr,
|
topk: tl.constexpr,
|
||||||
speculative_num_steps: tl.constexpr,
|
speculative_num_steps: tl.constexpr,
|
||||||
page_size: tl.constexpr,
|
page_size: tl.constexpr,
|
||||||
|
bs_upper: tl.constexpr,
|
||||||
|
iter_upper: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 32
|
BLOCK_SIZE: tl.constexpr = 128
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
kv_start = tl.load(seq_lens + pid)
|
|
||||||
|
|
||||||
if page_size == 1 or topk == 1:
|
if page_size == 1 or topk == 1:
|
||||||
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
|
copy_len = topk * speculative_num_steps
|
||||||
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
||||||
else:
|
else:
|
||||||
prefix_len = tl.load(seq_lens + pid)
|
bs_offset = tl.arange(0, bs_upper)
|
||||||
last_page_len = prefix_len % page_size
|
copy_len = tl.load(extend_lens + pid)
|
||||||
num_new_page = (
|
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
|
||||||
last_page_len + speculative_num_steps + page_size - 1
|
out_cache_ptr = out_cache_loc + cum_copy_len
|
||||||
) // page_size
|
|
||||||
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
|
|
||||||
|
|
||||||
|
# Part 1: Copy from out_cache_loc to req_to_token
|
||||||
|
kv_start = tl.load(seq_lens + pid)
|
||||||
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||||
|
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
|
||||||
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
|
|
||||||
for i in range(num_loop):
|
for i in range(num_loop):
|
||||||
save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
|
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||||
load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
mask = copy_offset < copy_len
|
||||||
mask = save_offset < kv_end
|
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
|
||||||
data = tl.load(out_cache_ptr + load_offset, mask=mask)
|
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
|
||||||
tl.store(token_pool + save_offset, data, mask=mask)
|
|
||||||
|
if page_size == 1 or topk == 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Part 2: Copy the indices for the last partial page
|
||||||
|
prefix_len = tl.load(seq_lens + pid)
|
||||||
|
last_page_len = prefix_len % page_size
|
||||||
|
offsets = tl.arange(0, page_size)
|
||||||
|
mask = offsets < last_page_len
|
||||||
|
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
|
||||||
|
prefix_base = token_pool + prefix_len - last_page_len
|
||||||
|
|
||||||
|
for topk_id in range(topk):
|
||||||
|
value = tl.load(prefix_base + offsets, mask=mask)
|
||||||
|
tl.store(
|
||||||
|
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
|
||||||
|
value,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Part 3: Remove the padding in out_cache_loc
|
||||||
|
iter_offest = tl.arange(0, iter_upper)
|
||||||
|
for topk_id in range(topk):
|
||||||
|
indices = tl.load(
|
||||||
|
prefix_base
|
||||||
|
+ topk_id * num_new_pages_per_topk_ * page_size
|
||||||
|
+ last_page_len
|
||||||
|
+ iter_offest,
|
||||||
|
mask=iter_offest < speculative_num_steps,
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
out_cache_loc
|
||||||
|
+ pid * topk * speculative_num_steps
|
||||||
|
+ topk_id * speculative_num_steps
|
||||||
|
+ iter_offest,
|
||||||
|
indices,
|
||||||
|
mask=iter_offest < speculative_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -629,20 +748,23 @@ def generate_draft_decode_kv_indices(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
positions,
|
positions,
|
||||||
num_seqs: tl.constexpr,
|
|
||||||
topk: tl.constexpr,
|
|
||||||
pool_len: tl.constexpr,
|
pool_len: tl.constexpr,
|
||||||
kv_indices_stride: tl.constexpr,
|
kv_indices_stride: tl.constexpr,
|
||||||
kv_indptr_stride: tl.constexpr,
|
kv_indptr_stride: tl.constexpr,
|
||||||
bs_upper: tl.constexpr,
|
bs_upper: tl.constexpr,
|
||||||
iter_upper: tl.constexpr,
|
iter_upper: tl.constexpr,
|
||||||
num_tokens_upper: tl.constexpr,
|
num_tokens_upper: tl.constexpr,
|
||||||
|
page_size: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 128
|
BLOCK_SIZE: tl.constexpr = 128
|
||||||
iters = tl.program_id(axis=0)
|
iters = tl.program_id(axis=0)
|
||||||
bid = tl.program_id(axis=1)
|
bid = tl.program_id(axis=1)
|
||||||
topk_id = tl.program_id(axis=2)
|
topk_id = tl.program_id(axis=2)
|
||||||
|
|
||||||
|
num_steps = tl.num_programs(axis=0)
|
||||||
|
num_seqs = tl.num_programs(axis=1)
|
||||||
|
topk = tl.num_programs(axis=2)
|
||||||
|
|
||||||
kv_indices += kv_indices_stride * iters
|
kv_indices += kv_indices_stride * iters
|
||||||
kv_indptr += kv_indptr_stride * iters
|
kv_indptr += kv_indptr_stride * iters
|
||||||
iters += 1
|
iters += 1
|
||||||
@@ -652,6 +774,7 @@ def generate_draft_decode_kv_indices(
|
|||||||
seq_len = tl.load(paged_kernel_lens + bid)
|
seq_len = tl.load(paged_kernel_lens + bid)
|
||||||
cum_seq_len = tl.sum(seq_lens)
|
cum_seq_len = tl.sum(seq_lens)
|
||||||
|
|
||||||
|
# Update kv_indices
|
||||||
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
|
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
|
||||||
kv_ptr = kv_indices + kv_offset
|
kv_ptr = kv_indices + kv_offset
|
||||||
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
|
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
|
||||||
@@ -665,10 +788,26 @@ def generate_draft_decode_kv_indices(
|
|||||||
kv_offset += BLOCK_SIZE
|
kv_offset += BLOCK_SIZE
|
||||||
|
|
||||||
extend_offset = tl.arange(0, iter_upper)
|
extend_offset = tl.arange(0, iter_upper)
|
||||||
extend_data = tl.load(
|
if page_size == 1 or topk == 1:
|
||||||
token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
|
extend_data = tl.load(
|
||||||
mask=extend_offset < iters,
|
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
|
||||||
)
|
mask=extend_offset < iters,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefix_len = seq_len
|
||||||
|
last_page_len = prefix_len % page_size
|
||||||
|
num_new_pages_per_topk = (
|
||||||
|
last_page_len + num_steps + page_size - 1
|
||||||
|
) // page_size
|
||||||
|
prefix_base = seq_len // page_size * page_size
|
||||||
|
start = (
|
||||||
|
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
|
||||||
|
)
|
||||||
|
extend_data = tl.load(
|
||||||
|
token_pool_ptr + start + extend_offset,
|
||||||
|
mask=extend_offset < iters,
|
||||||
|
)
|
||||||
|
|
||||||
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
||||||
|
|
||||||
# Update kv_indptr
|
# Update kv_indptr
|
||||||
@@ -707,6 +846,116 @@ def align_evict_mask_to_page_size(
|
|||||||
tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split one row of ``out_cache_loc`` into kept slots and slots to free.

    Program ``bid`` handles row ``bid`` of ``out_cache_loc`` (row length
    ``num_verify_tokens``):

    1. The first ``accept_length[bid] + 1`` entries are appended to
       ``tgt_cache_loc``.
    2. The last ``to_free_num_slots[bid]`` entries are appended to
       ``to_free_slots``.

    Both outputs are packed: the write offset of row ``bid`` is the sum of the
    lengths written by rows ``0 .. bid - 1``.

    Args:
        tgt_cache_loc: Packed output of kept slots.
        to_free_slots: Packed output of slots to release.
        accept_length: Per-row accepted length (kept length is this plus one).
        to_free_num_slots: Per-row number of trailing slots to release.
        out_cache_loc: Input slots, ``num_verify_tokens`` per row.
        num_verify_tokens: Row length of ``out_cache_loc``.
        num_verify_tokens_upper: Compile-time bound for the per-row range.
        bs_upper: Compile-time bound for the batch range.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)
    earlier_rows = bs_offset < bid
    row_ptr = out_cache_loc + bid * num_verify_tokens

    # Part 1: write the first part of the row to tgt_cache_loc.
    # `other=0` is required: lanes that are masked off are otherwise
    # unspecified and would corrupt the prefix sum below.
    accept_len_all = tl.load(accept_length + bs_offset, mask=earlier_rows, other=0)
    # Each earlier row contributed accept_length + 1 entries, hence "+ bid".
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    keep_len = tl.load(accept_length + bid) + 1
    keep_mask = offset < keep_len
    kept_slots = tl.load(row_ptr + offset, mask=keep_mask)
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        kept_slots,
        mask=keep_mask,
    )

    # Part 2: write the tail of the row to to_free_slots.
    to_free_num_slots_all = tl.load(
        to_free_num_slots + bs_offset, mask=earlier_rows, other=0
    )
    to_free_slots_start = tl.sum(to_free_num_slots_all)
    free_len = tl.load(to_free_num_slots + bid)
    free_mask = offset < free_len
    # The slots to free are the last `free_len` entries of the row.
    out_cache_loc_start = num_verify_tokens - free_len
    freed_slots = tl.load(row_ptr + out_cache_loc_start + offset, mask=free_mask)
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        freed_slots,
        mask=free_mask,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    accept_index: torch.Tensor,
    accept_length: torch.Tensor,
    draft_token_num: int,
    page_size: int,
):
    """Gather accepted cache locations and count the slots to free per request.

    Args:
        seq_lens: Per-request sequence lengths.
        out_cache_loc: Cache locations, indexed by ``accept_index``.
        accept_index: Indices into ``out_cache_loc`` of the accepted entries.
        accept_length: Per-request accept lengths.
        draft_token_num: Number of draft tokens added to every ``seq_lens``.
        page_size: Page size used to round the kept length up.

    Returns:
        ``(src_cache_loc, tgt_cache_loc, to_free_num_slots)`` where
        ``src_cache_loc`` is ``out_cache_loc[accept_index]``,
        ``tgt_cache_loc`` is an uninitialized tensor shaped like it, and
        ``to_free_num_slots`` is the extended length minus the kept length.
    """
    # Length of every request once all draft tokens are appended.
    full_len = seq_lens + draft_token_num

    # Length after accepting tokens, rounded up to a page boundary and then
    # capped at the extended length.
    accepted_len = seq_lens + accept_length + 1
    page_aligned_len = (accepted_len + page_size - 1) // page_size * page_size
    kept_len = torch.minimum(page_aligned_len, full_len)

    gathered = out_cache_loc[accept_index]
    return gathered, torch.empty_like(gathered), full_len - kept_len
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Re-pack ``tgt_cache_loc`` into ``out_cache_loc`` using filtered lengths.

    Each program instance (``bid = tl.program_id(0)``) copies
    ``accept_length_filter[bid]`` entries. The read position in
    ``tgt_cache_loc`` is ``sum(accept_length[:bid]) + bid``; the write position
    in ``out_cache_loc`` is ``sum(accept_length_filter[:bid])``. A program whose
    filter length is zero therefore copies nothing.

    Args:
        out_cache_loc: Output pointer for the re-packed entries.
        tgt_cache_loc: Input pointer holding the original packed entries.
        accept_length: Pointer to per-program accept lengths (source layout).
        accept_length_filter: Pointer to per-program copy lengths (destination
            layout).
        bs_upper: Width of the lane range used for the prefix sums.
            NOTE(review): presumably a power of two >= the number of
            programs -- confirm against the launcher.
        num_verify_tokens_upper: Width of the copy lane range; lanes beyond the
            copy length are masked off.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)
    earlier = bs_offset < bid

    # `other=0` makes the masked-off lanes explicitly zero, so the prefix sums
    # do not depend on whatever the backend leaves in those lanes.
    accept_length_all = tl.load(accept_length + bs_offset, mask=earlier, other=0)
    old_start = tl.sum(accept_length_all) + bid

    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=earlier, other=0
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    copy_mask = copy_offset < copy_len
    value = tl.load(tgt_cache_loc + old_start + copy_offset, mask=copy_mask)
    tl.store(out_cache_loc + new_start + copy_offset, value, mask=copy_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,
    seq_lens: torch.Tensor,
):
    """Build per-request copy lengths and advance ``seq_lens`` in place.

    Args:
        accept_length: Per-request accept lengths.
        unfinished_index_device: Indices of the requests that keep a non-zero
            entry in the result.
        seq_lens: Per-request sequence lengths; incremented in place by
            ``accept_length + 1`` for every request.

    Returns:
        A tensor shaped like ``accept_length`` holding ``accept_length + 1`` at
        ``unfinished_index_device`` and zero everywhere else.
    """
    step = accept_length + 1

    filtered = torch.zeros_like(accept_length)
    filtered[unfinished_index_device] = step[unfinished_index_device]

    # Side effect: every request's sequence length grows, not only the
    # unfinished ones.
    seq_lens.add_(step)
    return filtered
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True)
|
||||||
def select_top_k_tokens(
|
def select_top_k_tokens(
|
||||||
i: int,
|
i: int,
|
||||||
@@ -756,6 +1005,16 @@ def select_top_k_tokens(
|
|||||||
return input_ids, hidden_states, scores, tree_info
|
return input_ids, hidden_states, scores, tree_info
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk_torch(values, topk, dim):
    """Return the top-``topk`` values and their indices along ``dim``.

    For ``topk == 1`` this uses ``torch.max`` instead of ``torch.topk``. The
    reduced dimension is kept (size 1), so the result has the same shape as
    ``torch.topk(values, 1, dim=dim)`` for any input rank and any ``dim``.

    Args:
        values: Input tensor.
        topk: Number of entries to select along ``dim``.
        dim: Dimension to reduce over.

    Returns:
        A ``(values, indices)`` pair of tensors whose size along ``dim`` is
        ``topk``.
    """
    if topk == 1:
        # keepdim=True puts the size-1 axis at `dim` itself. A fixed
        # unsqueeze(1) is only correct when `dim` is axis 1 of a 2-D input.
        max_value, max_index = torch.max(values, dim=dim, keepdim=True)
        return max_value, max_index
    # General case: defer to torch.topk.
    return torch.topk(values, topk, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
def _generate_simulated_accept_index(
|
def _generate_simulated_accept_index(
|
||||||
accept_index,
|
accept_index,
|
||||||
predict,
|
predict,
|
||||||
@@ -765,15 +1024,35 @@ def _generate_simulated_accept_index(
|
|||||||
spec_steps,
|
spec_steps,
|
||||||
):
|
):
|
||||||
simulate_acc_len_float = float(simulate_acc_len)
|
simulate_acc_len_float = float(simulate_acc_len)
|
||||||
simulated_values = torch.normal(
|
if SIMULATE_ACC_METHOD == "multinomial":
|
||||||
mean=simulate_acc_len_float,
|
simulated_values = torch.normal(
|
||||||
std=1.0,
|
mean=simulate_acc_len_float,
|
||||||
size=(1,),
|
std=1.0,
|
||||||
device="cpu",
|
size=(1,),
|
||||||
)
|
device="cpu",
|
||||||
# clamp simulated values to be between 1 and self.spec_steps
|
)
|
||||||
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
|
# clamp simulated values to be between 1 and self.spec_steps
|
||||||
simulate_acc_len = int(simulated_values.round().item())
|
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
|
||||||
|
simulate_acc_len = int(simulated_values.round().item())
|
||||||
|
elif SIMULATE_ACC_METHOD == "match-expected":
|
||||||
|
# multinomial sampling does not match the expected length
|
||||||
|
# we keep it for the sake of compatibility of existing tests
|
||||||
|
# but it's better to use "match-expected" for the cases that need to
|
||||||
|
# match the expected length, One caveat is that this will only sample
|
||||||
|
# either round down or round up of the expected length
|
||||||
|
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
|
||||||
|
lower = int(simulate_acc_len_float // 1)
|
||||||
|
upper = lower + 1 if lower < spec_steps + 1 else lower
|
||||||
|
if lower == upper:
|
||||||
|
simulate_acc_len = lower
|
||||||
|
else:
|
||||||
|
weight_upper = simulate_acc_len_float - lower
|
||||||
|
weight_lower = 1.0 - weight_upper
|
||||||
|
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
|
||||||
|
sampled_index = torch.multinomial(probs, num_samples=1)
|
||||||
|
simulate_acc_len = lower if sampled_index == 0 else upper
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
|
||||||
|
|
||||||
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
|
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
|
||||||
sim_accept_index = torch.full(
|
sim_accept_index = torch.full(
|
||||||
@@ -864,9 +1143,9 @@ def generate_token_bitmask(
|
|||||||
"""
|
"""
|
||||||
Generate the logit mask for structured output.
|
Generate the logit mask for structured output.
|
||||||
Draft model's token can be either valid or invalid with respect to the grammar.
|
Draft model's token can be either valid or invalid with respect to the grammar.
|
||||||
We need to perform DFS to figure out:
|
We need to perform DFS to
|
||||||
1. which tokens are accepted by the grammar
|
1. figure out which tokens are accepted by the grammar.
|
||||||
2. what is the corresponding logit mask.
|
2. if so, what is the corresponding logit mask.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_draft_tokens = draft_tokens_cpu.shape[-1]
|
num_draft_tokens = draft_tokens_cpu.shape[-1]
|
||||||
@@ -883,6 +1162,7 @@ def generate_token_bitmask(
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
)
|
)
|
||||||
grammar = req.grammar
|
grammar = req.grammar
|
||||||
|
s = time.perf_counter()
|
||||||
traverse_tree(
|
traverse_tree(
|
||||||
retrieve_next_token_cpu[i],
|
retrieve_next_token_cpu[i],
|
||||||
retrieve_next_sibling_cpu[i],
|
retrieve_next_sibling_cpu[i],
|
||||||
@@ -892,6 +1172,12 @@ def generate_token_bitmask(
|
|||||||
i * num_draft_tokens : (i + 1) * num_draft_tokens
|
i * num_draft_tokens : (i + 1) * num_draft_tokens
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
tree_traverse_time = time.perf_counter() - s
|
||||||
|
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
|
||||||
|
logger.warning(
|
||||||
|
f"Bit mask generation took {tree_traverse_time} seconds with "
|
||||||
|
f"grammar: {req.grammar}"
|
||||||
|
)
|
||||||
|
|
||||||
verify_input.grammar = grammar
|
verify_input.grammar = grammar
|
||||||
return allocate_token_bitmask
|
return allocate_token_bitmask
|
||||||
|
|||||||
@@ -35,11 +35,17 @@ from sglang.srt.speculative.eagle_utils import (
|
|||||||
EagleVerifyInput,
|
EagleVerifyInput,
|
||||||
EagleVerifyOutput,
|
EagleVerifyOutput,
|
||||||
assign_draft_cache_locs,
|
assign_draft_cache_locs,
|
||||||
|
fast_topk,
|
||||||
generate_token_bitmask,
|
generate_token_bitmask,
|
||||||
select_top_k_tokens,
|
select_top_k_tokens,
|
||||||
)
|
)
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
|
from sglang.srt.utils import (
|
||||||
|
empty_context,
|
||||||
|
get_available_gpu_memory,
|
||||||
|
is_cuda,
|
||||||
|
next_power_of_2,
|
||||||
|
)
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
from sgl_kernel import segment_packbits
|
from sgl_kernel import segment_packbits
|
||||||
@@ -152,6 +158,12 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.init_attention_backend()
|
self.init_attention_backend()
|
||||||
self.init_cuda_graphs()
|
self.init_cuda_graphs()
|
||||||
|
|
||||||
|
# Some dummy tensors
|
||||||
|
self.num_new_pages_per_topk = torch.empty(
|
||||||
|
(), dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
def init_attention_backend(self):
|
def init_attention_backend(self):
|
||||||
# Create multi-step attn backends and cuda graph runners
|
# Create multi-step attn backends and cuda graph runners
|
||||||
if self.server_args.attention_backend == "flashinfer":
|
if self.server_args.attention_backend == "flashinfer":
|
||||||
@@ -254,7 +266,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Capture extend
|
# Capture extend
|
||||||
@@ -269,7 +281,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -290,7 +302,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
A tuple of the final logit output of the target model, next tokens accepted,
|
A tuple of the final logit output of the target model, next tokens accepted,
|
||||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
spec_info = self.draft(batch)
|
spec_info = self.draft(batch)
|
||||||
@@ -366,14 +377,21 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Allocate cache locations
|
# Allocate cache locations
|
||||||
|
# Layout of the out_cache_loc
|
||||||
|
# [ topk 0 ] [ topk 1 ]
|
||||||
|
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
|
||||||
if self.page_size == 1:
|
if self.page_size == 1:
|
||||||
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
||||||
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
|
num_seqs * self.speculative_num_steps * self.topk, backup_state=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.topk == 1:
|
if self.topk == 1:
|
||||||
prefix_lens = batch.seq_lens
|
prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
|
||||||
seq_lens = prefix_lens + self.speculative_num_steps
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.seq_lens,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
)
|
||||||
extend_num_tokens = num_seqs * self.speculative_num_steps
|
extend_num_tokens = num_seqs * self.speculative_num_steps
|
||||||
else:
|
else:
|
||||||
# In this case, the last partial page needs to be duplicated.
|
# In this case, the last partial page needs to be duplicated.
|
||||||
@@ -386,29 +404,33 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# "x" means speculative draft tokens
|
# "x" means speculative draft tokens
|
||||||
# "." means padded tokens
|
# "." means padded tokens
|
||||||
|
|
||||||
# TODO: fuse these ops
|
# TODO(lmzheng): The current implementation is still a fake support
|
||||||
prefix_lens = batch.seq_lens
|
# for page size > 1. In the `assign_draft_cache_locs` below,
|
||||||
last_page_lens = prefix_lens % self.page_size
|
# we directly move the indices instead of the real kv cache.
|
||||||
num_new_pages = (
|
# This only works when the kernel backend runs with page size = 1.
|
||||||
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
# If the kernel backend runs with page size > 1, we need to
|
||||||
) // self.page_size
|
# duplicate the real KV cache. The overhead of duplicating KV
|
||||||
seq_lens = (
|
# cache seems okay because the draft KV cache only has one layer.
|
||||||
prefix_lens // self.page_size * self.page_size
|
# see a related copy operation in MHATokenToKVPool::move_kv_cache.
|
||||||
+ num_new_pages * (self.page_size * self.topk)
|
|
||||||
)
|
(
|
||||||
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
|
prefix_lens,
|
||||||
raise NotImplementedError(
|
seq_lens,
|
||||||
"page_size > 1 and top_k > 1 are not supported."
|
last_loc,
|
||||||
)
|
self.num_new_pages_per_topk,
|
||||||
# TODO: Support page_size > 1 and top_k > 1
|
self.extend_lens,
|
||||||
# 1. Duplicate the KV cache in the last partial page for all top-k segments
|
) = get_last_loc_large_page_size_large_top_k(
|
||||||
# 2. Modify generate_draft_decode_kv_indices accordingly
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.seq_lens,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
self.topk,
|
||||||
|
self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(lmzheng): remove this device sync
|
||||||
|
extend_num_tokens = torch.sum(self.extend_lens).item()
|
||||||
|
|
||||||
last_loc = get_last_loc(
|
|
||||||
batch.req_to_token_pool.req_to_token,
|
|
||||||
batch.req_pool_indices,
|
|
||||||
prefix_lens,
|
|
||||||
)
|
|
||||||
out_cache_loc, token_to_kv_pool_state_backup = (
|
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||||
batch.alloc_paged_token_slots_extend(
|
batch.alloc_paged_token_slots_extend(
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
@@ -423,19 +445,31 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
batch.seq_lens,
|
batch.seq_lens,
|
||||||
|
self.extend_lens,
|
||||||
|
self.num_new_pages_per_topk,
|
||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
|
next_power_of_2(num_seqs),
|
||||||
|
next_power_of_2(self.speculative_num_steps),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.page_size > 1 and self.topk > 1:
|
||||||
|
# Remove padded slots
|
||||||
|
out_cache_loc = out_cache_loc[
|
||||||
|
: num_seqs * self.topk * self.speculative_num_steps
|
||||||
|
]
|
||||||
|
|
||||||
batch.out_cache_loc = out_cache_loc
|
batch.out_cache_loc = out_cache_loc
|
||||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||||
|
batch.return_hidden_states = False
|
||||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
batch.return_hidden_states = False
|
|
||||||
|
# Get forward batch
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
|
||||||
forward_batch = ForwardBatch.init_new(
|
forward_batch = ForwardBatch.init_new(
|
||||||
model_worker_batch, self.draft_model_runner
|
model_worker_batch, self.draft_model_runner
|
||||||
)
|
)
|
||||||
@@ -449,9 +483,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
else:
|
else:
|
||||||
# Initialize attention backend
|
# Initialize attention backend
|
||||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||||
forward_batch = ForwardBatch.init_new(
|
|
||||||
model_worker_batch, self.draft_model_runner
|
|
||||||
)
|
|
||||||
# Run forward steps
|
# Run forward steps
|
||||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||||
|
|
||||||
@@ -504,6 +535,13 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
if self.hot_token_id is not None:
|
if self.hot_token_id is not None:
|
||||||
topk_index = self.hot_token_id[topk_index]
|
topk_index = self.hot_token_id[topk_index]
|
||||||
|
|
||||||
|
out_cache_loc = out_cache_loc.reshape(
|
||||||
|
forward_batch.batch_size, self.topk, self.speculative_num_steps
|
||||||
|
)
|
||||||
|
out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
|
||||||
|
self.speculative_num_steps, -1
|
||||||
|
)
|
||||||
|
|
||||||
# Return values
|
# Return values
|
||||||
score_list: List[torch.Tensor] = []
|
score_list: List[torch.Tensor] = []
|
||||||
token_list: List[torch.Tensor] = []
|
token_list: List[torch.Tensor] = []
|
||||||
@@ -525,10 +563,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
|
|
||||||
# Set inputs
|
# Set inputs
|
||||||
forward_batch.input_ids = input_ids
|
forward_batch.input_ids = input_ids
|
||||||
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
|
forward_batch.out_cache_loc = out_cache_loc[i]
|
||||||
forward_batch.out_cache_loc = out_cache_loc[
|
|
||||||
:, self.topk * i : self.topk * (i + 1)
|
|
||||||
].flatten()
|
|
||||||
forward_batch.positions.add_(1)
|
forward_batch.positions.add_(1)
|
||||||
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||||
spec_info.hidden_states = hidden_states
|
spec_info.hidden_states = hidden_states
|
||||||
@@ -586,7 +621,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
if vocab_mask is not None:
|
if vocab_mask is not None:
|
||||||
assert spec_info.grammar is not None
|
assert spec_info.grammar is not None
|
||||||
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
|
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
|
||||||
# otherwise, this vocab mask will be the one from the previous extend stage
|
# NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
|
||||||
# and will be applied to produce wrong results
|
# and will be applied to produce wrong results
|
||||||
batch.sampling_info.vocab_mask = None
|
batch.sampling_info.vocab_mask = None
|
||||||
|
|
||||||
@@ -607,13 +642,13 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
]
|
]
|
||||||
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
|
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
|
||||||
|
|
||||||
|
if batch.return_logprob:
|
||||||
|
self.add_logprob_values(batch, res, logits_output)
|
||||||
|
|
||||||
# Prepare the batch for the next draft forwards.
|
# Prepare the batch for the next draft forwards.
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
batch.spec_info = res.draft_input
|
batch.spec_info = res.draft_input
|
||||||
|
|
||||||
if batch.return_logprob:
|
|
||||||
self.add_logprob_values(batch, res, logits_output)
|
|
||||||
|
|
||||||
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||||
|
|
||||||
def add_logprob_values(
|
def add_logprob_values(
|
||||||
@@ -626,8 +661,16 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
logits_output = res.logits_output
|
logits_output = res.logits_output
|
||||||
top_logprobs_nums = batch.top_logprobs_nums
|
top_logprobs_nums = batch.top_logprobs_nums
|
||||||
token_ids_logprobs = batch.token_ids_logprobs
|
token_ids_logprobs = batch.token_ids_logprobs
|
||||||
|
accepted_indices = res.accepted_indices
|
||||||
|
assert len(accepted_indices) == len(logits_output.next_token_logits)
|
||||||
|
temperatures = batch.sampling_info.temperatures
|
||||||
|
num_draft_tokens = batch.spec_info.draft_token_num
|
||||||
|
# acceptance indices are the indices in a "flattened" batch.
|
||||||
|
# dividing it to num_draft_tokens will yield the actual batch index.
|
||||||
|
temperatures = temperatures[accepted_indices // num_draft_tokens]
|
||||||
|
|
||||||
logprobs = torch.nn.functional.log_softmax(
|
logprobs = torch.nn.functional.log_softmax(
|
||||||
logits_output.next_token_logits, dim=-1
|
logits_output.next_token_logits / temperatures, dim=-1
|
||||||
)
|
)
|
||||||
batch_next_token_ids = res.verified_id
|
batch_next_token_ids = res.verified_id
|
||||||
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||||
@@ -662,7 +705,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
pt = 0
|
pt = 0
|
||||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||||
verified_ids = batch_next_token_ids.tolist()
|
verified_ids = batch_next_token_ids.tolist()
|
||||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
|
||||||
for _ in range(num_tokens):
|
for _ in range(num_tokens):
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||||
@@ -690,7 +733,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
hidden_states: Hidden states from the target model forward
|
hidden_states: Hidden states from the target model forward
|
||||||
next_token_ids: Next token ids generated from the target forward.
|
next_token_ids: Next token ids generated from the target forward.
|
||||||
"""
|
"""
|
||||||
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
|
|
||||||
batch.spec_info = EagleDraftInput(
|
batch.spec_info = EagleDraftInput(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
verified_id=next_token_ids,
|
verified_id=next_token_ids,
|
||||||
@@ -701,7 +743,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
model_worker_batch = batch.get_model_worker_batch(
|
model_worker_batch = batch.get_model_worker_batch(
|
||||||
seq_lens_cpu_cache=seq_lens_cpu
|
seq_lens_cpu_cache=seq_lens_cpu
|
||||||
)
|
)
|
||||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
|
||||||
forward_batch = ForwardBatch.init_new(
|
forward_batch = ForwardBatch.init_new(
|
||||||
model_worker_batch, self.draft_model_runner
|
model_worker_batch, self.draft_model_runner
|
||||||
)
|
)
|
||||||
@@ -724,9 +765,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch,
|
batch,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
)
|
)
|
||||||
batch.return_hidden_states = False
|
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
|
||||||
forward_batch = ForwardBatch.init_new(
|
forward_batch = ForwardBatch.init_new(
|
||||||
model_worker_batch, self.draft_model_runner
|
model_worker_batch, self.draft_model_runner
|
||||||
)
|
)
|
||||||
@@ -790,3 +829,47 @@ def load_token_map(token_map_path: str) -> List[int]:
|
|||||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||||
return torch.tensor(hot_token_id, dtype=torch.int32)
|
return torch.tensor(hot_token_id, dtype=torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens,
    speculative_num_steps: int,
):
    """Compute prefix lengths, extended lengths and last locations.

    Args:
        req_to_token: Passed through to ``get_last_loc``.
        req_pool_indices: Passed through to ``get_last_loc``.
        seq_lens: Current per-request sequence lengths; returned unchanged as
            the prefix lengths.
        speculative_num_steps: Number added to every sequence length.

    Returns:
        ``(prefix_lens, seq_lens, last_loc)`` where ``prefix_lens`` is the
        input ``seq_lens``, the new ``seq_lens`` is
        ``prefix_lens + speculative_num_steps``, and ``last_loc`` is the result
        of ``get_last_loc`` evaluated at the prefix lengths.
    """
    extended_lens = seq_lens + speculative_num_steps
    last_loc = get_last_loc(req_to_token, req_pool_indices, seq_lens)
    return seq_lens, extended_lens, last_loc
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
    topk: int,
    page_size: int,
):
    """Compute paged extension lengths for ``topk`` branches per request.

    Args:
        req_to_token: Passed through to ``get_last_loc``.
        req_pool_indices: Passed through to ``get_last_loc``.
        seq_lens: Current per-request sequence lengths; returned unchanged as
            the prefix lengths.
        speculative_num_steps: Number of steps each branch must hold.
        topk: Number of branches per request.
        page_size: Page size.

    Returns:
        ``(prefix_lens, seq_lens, last_loc, num_new_pages_per_topk,
        extend_lens)``. The new ``seq_lens`` is the prefix rounded down to a
        page boundary plus ``num_new_pages_per_topk * page_size * topk``;
        ``extend_lens`` is the new ``seq_lens`` minus the prefix lengths.
    """
    prefix = seq_lens

    # Tokens sitting in the last, partially filled page of the prefix.
    tail_len = prefix % page_size

    # Pages one branch needs to hold that tail plus its speculative steps
    # (ceil-division by page_size).
    pages_per_branch = (tail_len + speculative_num_steps + page_size - 1) // page_size

    aligned_prefix = prefix // page_size * page_size
    branch_capacity = pages_per_branch * (page_size * topk)
    new_seq_lens = aligned_prefix + branch_capacity
    growth = new_seq_lens - prefix

    last_loc = get_last_loc(req_to_token, req_pool_indices, prefix)

    return prefix, new_seq_lens, last_loc, pages_per_branch, growth
|
||||||
|
|||||||
@@ -441,5 +441,71 @@ class TestEAGLEServerTriton(TestEAGLEServer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEAGLEServerPageSize(TestEAGLEServer):
    """Run the inherited EAGLE server tests with page size 4 and top-k 1."""

    @classmethod
    def setUpClass(cls):
        """Launch the server with page size 4 on the flashinfer backend."""
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            5,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            6,
            "--mem-fraction-static",
            0.7,
            "--chunked-prefill-size",
            128,
            "--max-running-requests",
            8,
            "--page-size",
            4,
            "--attention-backend",
            "flashinfer",
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestEAGLEServerPageSizeTopk(TestEAGLEServer):
    """Run the inherited EAGLE server tests with page size 4 and top-k 8."""

    @classmethod
    def setUpClass(cls):
        """Launch the server with page size 4, top-k 8 and 64 draft tokens."""
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            5,
            "--speculative-eagle-topk",
            8,
            "--speculative-num-draft-tokens",
            64,
            "--mem-fraction-static",
            0.7,
            "--chunked-prefill-size",
            128,
            "--max-running-requests",
            8,
            "--page-size",
            4,
            "--attention-backend",
            "flashinfer",
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user