781 lines
30 KiB
Python
781 lines
30 KiB
Python
import logging
|
|
from copy import copy
|
|
from dataclasses import dataclass
|
|
from typing import ClassVar, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
|
from sglang.srt.managers.overlap_utils import FutureIndices
|
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|
from sglang.srt.mem_cache.common import (
|
|
alloc_paged_token_slots_extend,
|
|
alloc_token_slots,
|
|
get_last_loc,
|
|
)
|
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
|
from sglang.srt.server_args import get_global_server_args
|
|
from sglang.srt.speculative.eagle_info_v2 import (
|
|
EagleDraftInputV2Mixin,
|
|
EagleVerifyInputV2Mixin,
|
|
)
|
|
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
|
from sglang.srt.speculative.spec_utils import (
|
|
SIMULATE_ACC_LEN,
|
|
TREE_SPEC_KERNEL_AVAILABLE,
|
|
align_evict_mask_to_page_size,
|
|
assign_req_to_token_pool,
|
|
create_accept_length_filter,
|
|
create_extend_after_decode_spec_info,
|
|
filter_finished_cache_loc_kernel,
|
|
generate_simulated_accept_index,
|
|
get_src_tgt_cache_loc,
|
|
get_target_cache_loc,
|
|
)
|
|
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
|
|
|
if is_cuda():
|
|
from sgl_kernel import (
|
|
top_k_renorm_prob,
|
|
top_p_renorm_prob,
|
|
tree_speculative_sampling_target_only,
|
|
verify_tree_greedy,
|
|
)
|
|
elif is_hip():
|
|
from sgl_kernel import verify_tree_greedy
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
    """Inputs for the target-model verification pass of EAGLE speculative decoding.

    ``draft_token`` holds ``draft_token_num`` candidate tokens per request,
    stored flattened (``verify`` reshapes it to ``(bs, draft_token_num)``).
    The ``retrive_*`` tensors are handed to the tree verification kernels
    (``verify_tree_greedy`` / ``tree_speculative_sampling_target_only``).
    """

    draft_token: torch.Tensor
    custom_mask: torch.Tensor
    positions: torch.Tensor
    retrive_index: torch.Tensor
    retrive_next_token: torch.Tensor
    retrive_next_sibling: torch.Tensor
    retrive_cum_len: torch.Tensor
    spec_steps: int
    topk: int
    draft_token_num: int
    capture_hidden_mode: CaptureHiddenMode
    seq_lens_sum: int
    seq_lens_cpu: torch.Tensor
    grammar: BaseGrammarObject = None

    # Set after the "tree sampling kernel unavailable" warning has been logged
    # once, so the greedy fallback does not emit it on every verify step.
    _warned_tree_kernel_unavailable: ClassVar[bool] = False

    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_VERIFY)

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return the per-request token multipliers (both ``draft_token_num``)."""
        return self.draft_token_num, self.draft_token_num

    @classmethod
    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
        """Build an empty (zero-request) verify input for idle forward passes."""
        return cls(
            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
            retrive_index=torch.full(
                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
            ),
            retrive_next_token=torch.full(
                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
            ),
            retrive_next_sibling=torch.full(
                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
            ),
            retrive_cum_len=None,
            topk=topk,
            draft_token_num=num_verify_tokens,
            spec_steps=spec_steps,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=0,
            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
        )

    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
        """Allocate KV slots for the draft tokens and register them per request.

        Sets ``batch.input_ids`` to the draft tokens and ``batch.out_cache_loc``
        to freshly allocated slots. Then writes those slots into
        ``req_to_token`` for positions ``[seq_lens, seq_lens + draft_token_num)``.
        No-op for idle batches.
        """
        if batch.forward_mode.is_idle():
            return

        batch.input_ids = self.draft_token

        if page_size == 1:
            batch.out_cache_loc = alloc_token_slots(
                batch.tree_cache,
                len(batch.input_ids),
            )
            end_offset = batch.seq_lens + self.draft_token_num
        else:
            prefix_lens = batch.seq_lens
            prefix_lens_cpu = batch.seq_lens_cpu
            end_offset = prefix_lens + self.draft_token_num
            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            batch.out_cache_loc = alloc_paged_token_slots_extend(
                batch.tree_cache,
                prefix_lens,
                prefix_lens_cpu,
                end_offset,
                end_offset_cpu,
                last_loc,
                len(batch.input_ids),
            )
            self.last_loc = last_loc

        bs = batch.batch_size()
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            end_offset,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        """Build prefill-style attention arguments for the verify pass.

        Each request contributes ``draft_token_num`` query tokens. Its KV length
        is ``paged_kernel_lens[i] + draft_token_num``.

        Returns:
            ``(kv_indices, cum_kv_seq_len, qo_indptr, custom_mask)``.
        """
        batch_size = len(req_pool_indices)
        qo_indptr = torch.arange(
            0,
            (1 + batch_size) * self.draft_token_num,
            step=self.draft_token_num,
            dtype=torch.int32,
            device="cuda",
        )
        cum_kv_seq_len = torch.zeros(
            (batch_size + 1,), dtype=torch.int32, device="cuda"
        )

        # KV for the verify pass covers the prefix plus every draft token.
        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        kv_indices = torch.empty(
            paged_kernel_lens_sum + self.draft_token_num * batch_size,
            dtype=torch.int32,
            device="cuda",
        )
        create_flashinfer_kv_indices_triton[(batch_size,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

    def verify(
        self,
        batch: ScheduleBatch,
        logits_output: LogitsProcessorOutput,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
    ) -> "EagleVerifyOutput":
        """
        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).

        WARNING: This API in-place modifies the states of logits_output

        This API updates values inside logits_output based on the accepted
        tokens. I.e., logits_output.next_token_logits only contains
        accepted token logits.

        Side effects visible here:
          - appends accepted tokens to ``req.output_ids``;
          - advances ``batch.seq_lens`` / ``batch.seq_lens_cpu``;
          - frees KV slots of rejected draft tokens;
          - rewrites ``batch.out_cache_loc``.

        Returns:
            An ``EagleVerifyOutput`` holding the next draft input, the accepted
            token ids and per-request accept lengths.
        """
        if batch.forward_mode.is_idle():
            return EagleVerifyOutput(
                draft_input=EagleDraftInput.create_idle_input(
                    device=batch.device,
                    hidden_size=batch.model_config.hidden_size,
                    dtype=batch.model_config.dtype,
                    topk=self.topk,
                    capture_hidden_mode=CaptureHiddenMode.LAST,
                ),
                logits_output=logits_output,
                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
                accept_length_per_req_cpu=[],
                accepted_indices=torch.full(
                    (0, self.spec_steps + 1),
                    -1,
                    dtype=torch.int32,
                    device=batch.device,
                ),
            )

        bs = self.retrive_index.shape[0]
        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        sampling_info = batch.sampling_info

        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
        predict_shape[-1] += 1
        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
        accept_index = torch.full(
            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
        )
        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

        if bs != len(sampling_info):
            # The module-level import is ``from copy import copy`` (the function),
            # so ``copy.deepcopy`` would raise AttributeError. Import it directly.
            from copy import deepcopy

            sampling_info = deepcopy(sampling_info)
            # NOTE: retrive_index are the indices of the requests that are kept.
            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

        # Apply the custom logit processors if registered in the sampling info.
        if sampling_info.has_custom_logit_processor:
            apply_custom_logit_processor(
                logits_output.next_token_logits,
                sampling_info,
                num_tokens_in_batch=self.draft_token_num,
            )

        # Apply penalty
        if sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding:
            # one bias row per request, broadcast to all of its draft tokens.
            linear_penalty = torch.zeros(
                (bs, logits_output.next_token_logits.shape[1]),
                dtype=torch.float32,
                device="cuda",
            )
            sampling_info.apply_logits_bias(linear_penalty)
            logits_output.next_token_logits.add_(
                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
            )

        # Apply grammar mask
        if vocab_mask is not None:
            assert self.grammar is not None
            self.grammar.apply_vocab_mask(
                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
            )

        # Sample tokens. Force greedy sampling on AMD
        is_all_greedy = sampling_info.is_all_greedy
        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
            # Warn only once: this path runs on every decode step.
            if not EagleVerifyInput._warned_tree_kernel_unavailable:
                logger.warning(
                    "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
                    "Falling back to greedy verification."
                )
                EagleVerifyInput._warned_tree_kernel_unavailable = True

        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)

            verify_tree_greedy(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
            )
        else:
            # apply temperature and get target probs
            expanded_temperature = torch.repeat_interleave(
                sampling_info.temperatures, self.draft_token_num, dim=0
            )  # (bs * draft_token_num, 1)

            target_probs = F.softmax(
                logits_output.next_token_logits / expanded_temperature, dim=-1
            )  # (bs * draft_token_num, vocab_size)
            target_probs = top_k_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ks, self.draft_token_num, dim=0
                ),
            )  # (bs * draft_token_num, vocab_size)
            if not torch.all(sampling_info.top_ps == 1.0):
                target_probs = top_p_renorm_prob(
                    target_probs,
                    torch.repeat_interleave(
                        sampling_info.top_ps, self.draft_token_num, dim=0
                    ),
                )
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            draft_probs = torch.zeros(
                target_probs.shape, dtype=torch.float32, device="cuda"
            )

            # coins for rejection sampling
            coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
                (bs,), dtype=torch.float32, device="cuda"
            )
            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                uniform_samples=coins,
                uniform_samples_for_final_sampling=coins_for_final_sampling,
                target_probs=target_probs,
                draft_probs=draft_probs,
                threshold_single=get_global_server_args().speculative_accept_threshold_single,
                threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
                deterministic=True,
            )

        if SIMULATE_ACC_LEN > 0.0:
            # Do simulation
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
                bs=bs,
                spec_steps=self.spec_steps,
            )

        unfinished_index = []
        unfinished_accept_index = []
        accept_index_cpu = accept_index.tolist()
        predict_cpu = predict.tolist()
        has_finished = False

        # Iterate every accepted token and check if req has finished after append the token
        # should be checked BEFORE free kv cache slots
        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
            for j, idx in enumerate(accept_index_row):
                if idx == -1:
                    break
                token_id = predict_cpu[idx]
                req.output_ids.append(token_id)
                req.check_finished()
                if req.finished():
                    has_finished = True
                    # set all tokens after finished token to -1 and break
                    accept_index[i, j + 1 :] = -1
                    break
                if req.grammar is not None:
                    try:
                        req.grammar.accept_token(token_id)
                    except ValueError:
                        logger.info(
                            f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
                        )
                        raise
            if not req.finished():
                unfinished_index.append(i)
                # `idx == -1` means the row ended early at column j.
                if idx == -1:
                    unfinished_accept_index.append(accept_index[i, :j])
                else:
                    unfinished_accept_index.append(accept_index[i])
            req.spec_verify_ct += 1

        if has_finished:
            # Rows were truncated above; recompute the lengths from the mask.
            accept_length = (accept_index != -1).sum(dim=1) - 1

        # Free the KV cache for unaccepted tokens
        # TODO: fuse them
        accept_index = accept_index[accept_index != -1]
        verified_id = predict[accept_index]
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[accept_index] = False
        accept_length_cpu = accept_length.cpu()
        # FIXME: this `tolist()` fixes the numerical calculation consistency
        # try to unify the tensor representation and list representation
        accept_length_list = accept_length_cpu.tolist()

        if page_size == 1:
            # TODO: boolean array index leads to a device sync. Remove it.
            token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
        else:
            if self.topk == 1:
                # Only evict full empty page. Do not evict partial empty page
                align_evict_mask_to_page_size[len(batch.seq_lens),](
                    batch.seq_lens,
                    evict_mask,
                    page_size,
                    self.draft_token_num,
                    next_power_of_2(self.draft_token_num),
                )
                token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
            else:
                # Shift the accepted tokens to the beginning.
                # Only evict the last part
                src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
                    batch.seq_lens,
                    batch.out_cache_loc,
                    accept_index,
                    accept_length,
                    self.draft_token_num,
                    page_size,
                )
                to_free_slots = torch.empty(
                    (to_free_num_slots.sum().item(),),
                    dtype=torch.int64,
                    device=to_free_num_slots.device,
                )

                # out_cache_loc: [0  1  2,  3  4  5,  6  7  8]
                # accept_index:  [0 -1  2,  3  4 -1,  6 -1 -1]
                # tgt_cache_loc: [0  1   ,  3  4   ,  6      ]
                # to_free_slots: [      2,        5,     7  8]
                # to_free_slots also needs to be page-aligned without the first partial page
                #
                # split each row of out_cache_loc into two parts.
                # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
                # 2. the second part goes to to_free_slots.
                get_target_cache_loc[(bs,)](
                    tgt_cache_loc,
                    to_free_slots,
                    accept_length,
                    to_free_num_slots,
                    batch.out_cache_loc,
                    self.draft_token_num,
                    next_power_of_2(self.draft_token_num),
                    next_power_of_2(bs),
                )

                # Free the kv cache
                token_to_kv_pool_allocator.free(to_free_slots)

                # Copy the kv cache
                batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
                    tgt_cache_loc, src_cache_loc
                )

        # Construct EagleVerifyOutput
        if not has_finished:
            if page_size == 1 or self.topk == 1:
                batch.out_cache_loc = batch.out_cache_loc[accept_index]
                assign_req_to_token_pool[(bs,)](
                    batch.req_pool_indices,
                    batch.req_to_token_pool.req_to_token,
                    batch.seq_lens,
                    batch.seq_lens + accept_length + 1,
                    batch.out_cache_loc,
                    batch.req_to_token_pool.req_to_token.shape[1],
                    next_power_of_2(bs),
                )
            else:
                batch.out_cache_loc = tgt_cache_loc
            batch.seq_lens.add_(accept_length + 1)
            batch.seq_lens_cpu.add_(accept_length_cpu + 1)

            draft_input = EagleDraftInput(
                hidden_states=batch.spec_info.hidden_states[accept_index],
                verified_id=verified_id,
                accept_length=accept_length,
                accept_length_cpu=accept_length_list,
                seq_lens_for_draft_extend=batch.seq_lens,
                seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
                req_pool_indices_for_draft_extend=batch.req_pool_indices,
            )

            return EagleVerifyOutput(
                draft_input=draft_input,
                logits_output=logits_output,
                verified_id=verified_id,
                accept_length_per_req_cpu=draft_input.accept_length_cpu,
                accepted_indices=accept_index,
            )
        else:
            if page_size == 1 or self.topk == 1:
                assign_req_to_token_pool[(bs,)](
                    batch.req_pool_indices,
                    batch.req_to_token_pool.req_to_token,
                    batch.seq_lens,
                    batch.seq_lens + accept_length + 1,
                    batch.out_cache_loc[accept_index],
                    batch.req_to_token_pool.req_to_token.shape[1],
                    next_power_of_2(bs),
                )
                batch.seq_lens.add_(accept_length + 1)
                batch.seq_lens_cpu.add_(accept_length_cpu + 1)

            if len(unfinished_accept_index) > 0:
                unfinished_accept_index = torch.cat(unfinished_accept_index)
                unfinished_index_device = torch.tensor(
                    unfinished_index, dtype=torch.int64, device=predict.device
                )
                draft_input_accept_length_cpu = [
                    accept_length_list[i] for i in unfinished_index
                ]
                if page_size == 1 or self.topk == 1:
                    batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
                else:
                    batch.out_cache_loc = torch.empty(
                        len(unfinished_index) + sum(draft_input_accept_length_cpu),
                        dtype=torch.int64,
                        device=predict.device,
                    )
                    accept_length_filter = create_accept_length_filter(
                        accept_length,
                        unfinished_index_device,
                        batch.seq_lens,
                    )
                    batch.seq_lens_cpu.add_(accept_length_cpu + 1)
                    filter_finished_cache_loc_kernel[(bs,)](
                        batch.out_cache_loc,
                        tgt_cache_loc,
                        accept_length,
                        accept_length_filter,
                        next_power_of_2(bs),
                        next_power_of_2(self.draft_token_num),
                    )

                draft_input = EagleDraftInput(
                    hidden_states=batch.spec_info.hidden_states[
                        unfinished_accept_index
                    ],
                    verified_id=predict[unfinished_accept_index],
                    accept_length_cpu=draft_input_accept_length_cpu,
                    accept_length=accept_length[unfinished_index_device],
                    seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
                    seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
                    req_pool_indices_for_draft_extend=batch.req_pool_indices[
                        unfinished_index_device
                    ],
                )
            else:
                draft_input = EagleDraftInput.create_idle_input(
                    device=batch.device,
                    hidden_size=batch.model_config.hidden_size,
                    dtype=batch.model_config.dtype,
                    topk=self.topk,
                    capture_hidden_mode=CaptureHiddenMode.LAST,
                )

            return EagleVerifyOutput(
                draft_input=draft_input,
                logits_output=logits_output,
                verified_id=verified_id,
                accept_length_per_req_cpu=accept_length_list,
                accepted_indices=accept_index,
            )
|
|
|
|
|
|
@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    """State handed to the EAGLE draft model between speculative steps.

    Groups the tensors used by the draft decode pass (``topk_p``, ``topk_index``,
    ``hidden_states``), the extend / draft-extend passes (``verified_id``,
    ``accept_length``, ``*_for_draft_extend``) and the V2 overlap worker
    (``future_indices``, ``allocate_lens``, ...).
    """

    # The inputs for decode
    # shape: (b, topk)
    topk_p: torch.Tensor = None
    topk_index: torch.Tensor = None
    # shape: (b, hidden_size)
    hidden_states: torch.Tensor = None
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL

    # Inputs for extend
    # shape: (b,)
    verified_id: torch.Tensor = None
    accept_length: torch.Tensor = None
    accept_length_cpu: List[int] = None

    # Inputs for the attention backends
    # shape: (b + 1,)
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None

    # Shape info for padding
    num_tokens_per_batch: int = -1
    num_tokens_for_logprob_per_batch: int = -1

    # Inputs for draft extend
    # shape: (b,)
    seq_lens_for_draft_extend: torch.Tensor = None
    seq_lens_for_draft_extend_cpu: torch.Tensor = None
    req_pool_indices_for_draft_extend: torch.Tensor = None

    # Inputs for V2 overlap worker
    future_indices: Optional[FutureIndices] = None
    allocate_lens: Optional[torch.Tensor] = None
    new_seq_lens: Optional[torch.Tensor] = None
    verify_done: Optional[torch.cuda.Event] = None

    # FIXME(lsyin): remove this hack
    ALLOC_LEN_PER_DECODE: ClassVar[int] = None

    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_DRAFT)

    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return ``(num_tokens_per_batch, num_tokens_for_logprob_per_batch)``."""
        return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch

    def prepare_for_extend(self, batch: ScheduleBatch):
        """Shift each request's input ids left by one and append its verified id.

        Rewrites ``batch.input_ids`` in place, one ``extend_lens[i]`` segment at
        a time. No-op for idle batches.
        """
        if batch.forward_mode.is_idle():
            return

        # Prefill only generate 1 token.
        assert len(self.verified_id) == len(batch.seq_lens)

        pt = 0
        for i, extend_len in enumerate(batch.extend_lens):
            input_ids = batch.input_ids[pt : pt + extend_len]
            batch.input_ids[pt : pt + extend_len] = torch.cat(
                (input_ids[1:], self.verified_id[i].reshape(1))
            )
            pt += extend_len

    @classmethod
    def create_idle_input(
        cls,
        device: torch.device,
        hidden_size: int,
        dtype: torch.dtype,
        topk: int,
        capture_hidden_mode: CaptureHiddenMode,
    ):
        """Build an empty (zero-request) draft input for idle forward passes."""
        return cls(
            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
            capture_hidden_mode=capture_hidden_mode,
            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
            accept_length_cpu=[],
        )

    def prepare_extend_after_decode(
        self,
        batch: ScheduleBatch,
        speculative_num_steps: int,
    ):
        """Reconfigure ``batch`` for the draft-extend pass that follows verify.

        Each request extends by ``accept_length + 1`` tokens. ``positions`` and
        ``verified_id`` are filled by the ``create_extend_after_decode_spec_info``
        kernel. No-op for idle batches.
        """
        if batch.forward_mode.is_idle():
            return

        batch.input_ids = self.verified_id
        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
        batch.extend_num_tokens = sum(batch.extend_lens)
        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
        batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
        batch.return_logprob = False
        batch.return_hidden_states = False

        self.capture_hidden_mode = CaptureHiddenMode.LAST
        self.accept_length.add_(1)
        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)

        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
            batch.input_ids,
            batch.seq_lens,
            self.accept_length,
            self.positions,
            self.verified_id,
            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
        )

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: Optional[int],
        req_to_token: torch.Tensor,
    ):
        """Build prefill-style attention arguments for the draft-extend pass.

        Query lengths come from ``self.accept_length``. KV lengths come from
        ``paged_kernel_lens``. When ``paged_kernel_lens_sum`` is None it is
        computed from the cumulative KV lengths.

        Returns:
            ``(kv_indices, cum_kv_seq_len, qo_indptr, None)``.
        """
        bs = self.accept_length.numel()
        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        if paged_kernel_lens_sum is None:
            # Convert explicitly: torch.empty expects an int size, not a 0-d tensor.
            paged_kernel_lens_sum = int(cum_kv_seq_len[-1].item())

        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(bs,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, qo_indptr, None

    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
        """Keep only the requests selected by ``new_indices``.

        For the V2 overlap worker (``future_indices`` set), only
        ``future_indices`` and ``allocate_lens`` are filtered. Otherwise, if
        ``has_been_filtered`` is True, the tensors are assumed to be compacted
        already and are only truncated to ``len(new_indices)``. If it is False,
        they are gathered by ``new_indices``.
        """
        if self.future_indices is not None:
            self.future_indices.indices = self.future_indices.indices[new_indices]
            self.allocate_lens = self.allocate_lens[new_indices]
            return

        if has_been_filtered:
            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
            # therefore, we don't need to filter the batch again in scheduler
            if len(new_indices) != len(self.topk_p):
                logger.warning(
                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
                )
            self.topk_p = self.topk_p[: len(new_indices)]
            self.topk_index = self.topk_index[: len(new_indices)]
            self.hidden_states = self.hidden_states[: len(new_indices)]
            self.verified_id = self.verified_id[: len(new_indices)]
        else:
            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
            self.topk_p = self.topk_p[new_indices]
            self.topk_index = self.topk_index[new_indices]
            self.hidden_states = self.hidden_states[new_indices]
            self.verified_id = self.verified_id[new_indices]

    def merge_batch(self, spec_info: "EagleDraftInput"):
        """Append ``spec_info``'s per-request tensors after this input's.

        If this input has no ``hidden_states`` yet, it adopts ``spec_info``'s
        tensors. If ``spec_info`` has none, nothing is merged.
        """
        if self.future_indices is not None:
            assert spec_info.future_indices is not None
            self.future_indices = FutureIndices(
                indices=torch.cat(
                    [self.future_indices.indices, spec_info.future_indices.indices]
                )
            )
            self.allocate_lens = torch.cat(
                [self.allocate_lens, spec_info.allocate_lens]
            )
            return

        if self.hidden_states is None:
            self.hidden_states = spec_info.hidden_states
            self.verified_id = spec_info.verified_id
            self.topk_p = spec_info.topk_p
            self.topk_index = spec_info.topk_index
            return
        if spec_info.hidden_states is None:
            return
        self.hidden_states = torch.cat(
            [self.hidden_states, spec_info.hidden_states], dim=0
        )
        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0)
        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
|
|
|
|
|
@dataclass
class EagleVerifyOutput:
    """Result of ``EagleVerifyInput.verify``.

    Bundles the input for the next draft step with the accepted-token
    bookkeeping produced by verification.
    """

    # Input for the next draft-model step.
    draft_input: EagleDraftInput
    # Target-model logits output passed back from verification.
    logits_output: LogitsProcessorOutput
    # Ids of every accepted token, bonus token included.
    verified_id: torch.Tensor
    # Per-request accepted length, as a host-side Python list.
    accept_length_per_req_cpu: List[int]
    # Indices into logits_output.next_token_logits of the accepted tokens.
    accepted_indices: torch.Tensor
|