Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
515 lines
18 KiB
Python
515 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, List, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
|
from sglang.srt.managers.scheduler import global_server_args_dict
|
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
|
from sglang.srt.model_executor.forward_batch_info import (
|
|
CaptureHiddenMode,
|
|
ForwardBatch,
|
|
ForwardMode,
|
|
)
|
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
|
|
from sglang.srt.speculative.spec_utils import (
|
|
SIMULATE_ACC_LEN,
|
|
generate_simulated_accept_index,
|
|
)
|
|
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
|
EAGLEDraftCudaGraphRunner,
|
|
)
|
|
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
|
|
|
if is_cuda():
|
|
from sgl_kernel import (
|
|
top_k_renorm_prob,
|
|
top_p_renorm_prob,
|
|
tree_speculative_sampling_target_only,
|
|
verify_tree_greedy,
|
|
)
|
|
from sgl_kernel.top_k import fast_topk
|
|
elif is_hip():
|
|
from sgl_kernel import verify_tree_greedy
|
|
|
|
|
|
@triton.jit
def assign_draft_cache_locs_page_size_1(
    req_pool_indices,
    req_to_token,
    seq_lens,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
):
    # One program per request. Program `req_id` reads the row
    # `req_to_token[req_pool_indices[req_id]]` (row stride = pool_len) starting at
    # column `seq_lens[req_id]`, and copies `topk * speculative_num_steps`
    # consecutive entries into `out_cache_loc[req_id * topk * speculative_num_steps:]`.
    BLOCK_SIZE: tl.constexpr = 128
    req_id = tl.program_id(axis=0)
    n_slots = topk * speculative_num_steps

    # Source: start of this request's token row, advanced to its current length.
    row_ptr = req_to_token + tl.load(req_pool_indices + req_id) * pool_len
    src_base = row_ptr + tl.load(seq_lens + req_id)
    # Destination: this request's contiguous chunk of out_cache_loc.
    dst_base = out_cache_loc + req_id * n_slots

    for blk in range(tl.cdiv(n_slots, BLOCK_SIZE)):
        offs = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_range = offs < n_slots
        vals = tl.load(src_base + offs, mask=in_range)
        tl.store(dst_base + offs, vals, mask=in_range)
|
|
|
|
|
|
@dataclass
class EagleDraftInputV2Mixin:
    """Overlap-scheduling (v2) helpers mixed into ``EagleDraftInput``.

    Both methods mutate the given ``ModelWorkerBatch`` in place, turning it
    into a draft-model batch, and then build a ``ForwardBatch`` from it.
    """

    def prepare_for_v2_draft(
        self: EagleDraftInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        cuda_graph_runner: EAGLEDraftCudaGraphRunner,
        draft_model_runner: ModelRunner,
        topk: int,
        num_steps: int,
    ):
        """Prepare ``batch`` for the multi-step draft decode.

        Gathers ``topk * num_steps`` KV-cache slots per request from
        ``req_to_token_pool`` (starting at each request's current sequence
        length) into ``batch.out_cache_loc``. It also sets ``self.positions``
        and builds the draft forward batch.

        Returns:
            ``(forward_batch, can_cuda_graph)``. ``can_cuda_graph`` is a bool
            that is ``True`` only when a CUDA graph runner is given and it
            accepts ``forward_batch``.
        """
        bs = len(batch.seq_lens)

        # Assign cache locations: one contiguous chunk of topk * num_steps
        # slots per request, laid out request-major.
        batch.out_cache_loc = torch.empty(
            (bs * topk * num_steps,),
            dtype=torch.int64,
            device=batch.input_ids.device,
        )
        # FIXME(lsyin): align with the default code path
        assign_draft_cache_locs_page_size_1[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            topk,
            num_steps,
        )

        # Get a forward batch
        batch.capture_hidden_mode = CaptureHiddenMode.LAST
        # Each of a request's topk branches starts at that request's seq_len.
        self.positions = batch.seq_lens.repeat_interleave(topk, dim=0)
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        # Coerce to bool so a missing runner yields False rather than None,
        # matching prepare_for_v2_verify.
        can_cuda_graph = bool(
            cuda_graph_runner and cuda_graph_runner.can_run(forward_batch)
        )
        return forward_batch, can_cuda_graph

    def prepare_for_extend_to_fill_draft_kvcache(
        self,
        batch: ModelWorkerBatch,
        predict: torch.Tensor,
        num_draft_tokens: int,
        draft_model_runner: Any,
    ):
        """Turn ``batch`` into a ``DRAFT_EXTEND_V2`` batch over ``predict``.

        Every request is extended by ``num_draft_tokens`` tokens taken from
        ``predict``. The pre-extension sequence lengths become the prefix
        lengths. The draft attention backend's forward metadata is
        initialized before returning.

        Returns:
            The ``ForwardBatch`` built from the mutated ``batch``.
        """
        bs = len(batch.seq_lens)
        # Keep the pre-extension CPU lengths; they become the prefix lengths.
        # (The update below is out-of-place, so this tensor is not modified.)
        prefix_lens_cpu = batch.seq_lens_cpu
        extend_num_tokens = bs * num_draft_tokens

        batch.spec_info = self
        batch.input_ids = predict
        batch.seq_lens = batch.seq_lens + num_draft_tokens
        batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
        batch.seq_lens_sum += extend_num_tokens
        batch.extend_seq_lens = [num_draft_tokens] * bs
        batch.extend_prefix_lens = prefix_lens_cpu.tolist()
        batch.extend_prefix_lens_cpu = prefix_lens_cpu
        batch.extend_num_tokens = extend_num_tokens
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
        return forward_batch
|
|
|
|
|
|
@dataclass
class EagleVerifyInputV2Mixin:
    """Overlap-scheduling (v2) helpers mixed into ``EagleVerifyInput``."""

    def prepare_for_v2_verify(
        self: EagleVerifyInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        target_worker: TpModelWorker,
    ):
        """Prepare ``batch`` for a ``TARGET_VERIFY`` forward on the target model.

        Uses ``self.draft_token`` as the input ids and gathers the KV-cache
        slots for positions ``[seq_len, seq_len + draft_token_num)`` of every
        request into ``batch.out_cache_loc``. Then either prepares CUDA graph
        replay or initializes the attention backend's forward metadata.

        Returns:
            ``(verify_forward_batch, can_run_cuda_graph)``.
        """
        # Assign cache locations
        bs = len(batch.req_pool_indices)
        batch.input_ids = self.draft_token
        device = batch.input_ids.device
        batch.out_cache_loc = torch.empty(
            (bs * self.draft_token_num,),
            dtype=torch.int64,
            device=device,
        )

        # Copies req_to_token[..., seq_len : seq_len + draft_token_num] per
        # request, packed back-to-back into out_cache_loc.
        assign_extend_cache_locs[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.draft_token_num,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )

        # Get a forward batch
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner)

        # Run attention backend plan and cuda graph preparation
        can_run_cuda_graph = bool(
            target_worker.model_runner.graph_runner
            and target_worker.model_runner.graph_runner.can_run(verify_forward_batch)
        )
        if can_run_cuda_graph:
            target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch)
        else:
            target_worker.model_runner.attn_backend.init_forward_metadata(
                verify_forward_batch
            )

        return verify_forward_batch, can_run_cuda_graph

    def sample(
        self: EagleVerifyInput,
        batch: ModelWorkerBatch,
        logits_output: LogitsProcessorOutput,
    ):
        """
        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).

        Greedy batches use ``verify_tree_greedy``. Otherwise target
        probabilities are built from temperature, top-k and top-p, and
        ``tree_speculative_sampling_target_only`` is used.

        Returns:
            ``(predict, accept_length, accept_index)``:
            - ``predict``: flat ``(bs * (spec_steps + 1),)`` int32 tensor.
            - ``accept_length``: ``(bs,)`` int32, including the bonus token.
            - ``accept_index``: ``(bs, spec_steps + 1)`` int32, ``-1`` where unused.
        """
        bs = len(batch.seq_lens)
        sampling_info = batch.sampling_info
        next_token_logits = logits_output.next_token_logits
        device = batch.input_ids.device

        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        predict = torch.zeros(
            (bs * (self.spec_steps + 1),), dtype=torch.int32, device=device
        )
        accept_index = torch.full(
            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device
        )
        accept_length = torch.empty((bs,), dtype=torch.int32, device=device)

        # Sample tokens
        if sampling_info.is_all_greedy:
            target_predict = torch.argmax(next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)

            verify_tree_greedy(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
            )
        else:
            # Apply temperature and get target probs
            expanded_temperature = torch.repeat_interleave(
                sampling_info.temperatures, self.draft_token_num, dim=0
            )  # (bs * num_draft_tokens, 1)

            target_probs = F.softmax(
                next_token_logits / expanded_temperature, dim=-1
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_k_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ks, self.draft_token_num, dim=0
                ),
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_p_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ps, self.draft_token_num, dim=0
                ),
            )
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            # This is currently not used
            draft_probs = torch.empty_like(target_probs)

            # coins for rejection sampling
            coins = torch.rand_like(candidates, dtype=torch.float32, device=device)
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
                (bs,), dtype=torch.float32, device=device
            )

            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                uniform_samples=coins,
                uniform_samples_for_final_sampling=coins_for_final_sampling,
                target_probs=target_probs,
                draft_probs=draft_probs,
                threshold_single=global_server_args_dict[
                    "speculative_accept_threshold_single"
                ],
                threshold_acc=global_server_args_dict[
                    "speculative_accept_threshold_acc"
                ],
                deterministic=True,
            )

        if SIMULATE_ACC_LEN > 0:
            # Do simulation. accept_index/predict above were sized with
            # spec_steps + 1 columns, so pass spec_steps (not draft_token_num)
            # to keep the simulated accept_index the same shape.
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
                simulate_acc_len=SIMULATE_ACC_LEN,
                bs=bs,
                spec_steps=self.spec_steps,
            )

        # Include the bonus token
        accept_length.add_(1)
        return predict, accept_length, accept_index
|
|
|
|
|
|
def build_tree_kernel_efficient_tmp(
    verified_id: torch.Tensor,
    parent_list: List[torch.Tensor],
    top_scores_index: torch.Tensor,
    draft_tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    """Build the EAGLE verification tree buffers with ``sgl_kernel``.

    Prepends ``verified_id`` to each request's ``draft_tokens``. It then
    allocates (or reuses ``tree_mask_buf`` / ``position_buf``) the tree
    mask, positions and retrieve buffers, and fills them via
    ``sgl_kernel.build_tree_kernel_efficient``.

    Tree mask layout by ``tree_mask_mode``:
      * ``FULL_MASK``: bool, ``sum_i num_verify_tokens * (seq_lens[i] + num_verify_tokens)``
        entries (one row per draft token over prefix + draft tokens), flattened.
      * ``QLEN_ONLY``: bool, ``bs * num_verify_tokens * num_verify_tokens`` entries.
      * ``QLEN_ONLY_BITPACKING``: ``bs * num_verify_tokens`` unsigned ints, each
        holding one row as bits.

    Returns:
        ``(tree_mask, positions, retrive_index, retrive_next_token,
        retrive_next_sibling, draft_tokens)``. ``draft_tokens`` is the flattened
        ``[verified_id, *drafts]`` per request.

    Raises:
        NotImplementedError: for an unknown ``tree_mask_mode``, or when a
            bit-packed mask would need more than 32 bits per row.
    """
    # TODO(lsyin): make it compatible with default code path
    # TODO(lsyin): support cuda graph padding for eagle
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device
    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # for QLEN_ONLY_BITPACKING, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        # Reuse the caller's buffer; reset it to the mode's neutral value.
        tree_mask = tree_mask_buf
        if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
            tree_mask.fill_(True)
        elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            tree_mask.fill_(0)
        elif tree_mask_mode == TreeMaskMode.FULL_MASK:
            tree_mask.fill_(True)
        else:
            raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        # Smallest unsigned dtype with at least num_verify_tokens bits.
        num_bytes = (num_verify_tokens + 7) // 8
        if num_bytes <= 1:
            packed_dtype = torch.uint8
        elif num_bytes <= 2:
            packed_dtype = torch.uint16
        elif num_bytes <= 4:
            packed_dtype = torch.uint32
        else:
            raise NotImplementedError(
                "Bit-packed tree mask supports at most 32 verify tokens, "
                f"got {num_verify_tokens=}"
            )
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtype,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    retrive_buf = torch.full(
        (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf
    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )

    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )
|
|
|
|
|
|
@torch.compile(dynamic=True)
def select_top_k_tokens_tmp(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Choose the next ``topk`` draft branches per request for draft step ``i``.

    At step 0 (right after extend), each request fans out into ``topk``
    branches taken directly from ``topk_p`` / ``topk_index``. At later steps,
    the running branch ``scores`` are multiplied with the new per-branch
    ``topk_p`` and the best ``topk`` of the ``topk * topk`` candidates are kept.

    Returns:
        ``(input_ids, hidden_states, scores, tree_info)``. ``tree_info`` is a
        3-tuple (score tensor, candidate token ids, index tensor) collected
        for later tree construction.
    """
    # FIXME(lsyin): remove this duplicate code
    if i == 0:
        # The first step after extend: every request expands into topk rows.
        num_reqs = topk_p.shape[0]
        step_indices = (
            torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
            .unsqueeze(0)
            .repeat(num_reqs, 1)
        )  # shape: (b, topk + 1)
        first_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            step_indices,
        )
        return (
            topk_index.flatten(),
            hidden_states.repeat_interleave(topk, dim=0),
            topk_p,  # shape: (b, topk)
            first_info,
        )

    # The later decode steps
    # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
    joint_scores = torch.mul(scores.unsqueeze(2), topk_p.reshape(-1, topk, topk))
    best_p, best_pos = fast_topk(
        joint_scores.flatten(start_dim=1), topk, dim=-1
    )  # (b, topk)

    candidate_ids = topk_index.reshape(-1, topk**2)  # (b, topk * topk)
    next_input_ids = torch.gather(candidate_ids, index=best_pos, dim=1).flatten()

    # Row of hidden_states behind each chosen candidate: branch (pos // topk)
    # within its request's block of topk consecutive rows.
    block_starts = torch.arange(
        0, hidden_states.shape[0], step=topk, device=hidden_states.device
    ).repeat_interleave(topk)
    source_rows = best_pos.flatten() // topk + block_starts
    next_hidden_states = hidden_states[source_rows, :]

    later_info = (
        joint_scores,  # shape: (b, topk, topk)
        candidate_ids,  # shape: (b, topk * topk)
        best_pos + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
    )
    return next_input_ids, next_hidden_states, best_p, later_info
|
|
|
|
|
|
@triton.jit
def fill_new_verified_id(
    verified_id,
    accept_lens,
    new_verified_id,
    num_draft_tokens: tl.constexpr,
):
    # One program per request `req_id`:
    #   new_verified_id[req_id] =
    #       verified_id[req_id * num_draft_tokens + accept_lens[req_id] - 1]
    # i.e. the entry at position accept_lens[req_id] - 1 inside this request's
    # stretch of num_draft_tokens values.
    # NOTE: we cannot fuse any in-place operations of `accept_lens` inside this kernel
    # because this kernel reads accept_lens
    req_id = tl.program_id(axis=0)
    src_idx = req_id * num_draft_tokens + tl.load(accept_lens + req_id) - 1
    tl.store(new_verified_id + req_id, tl.load(verified_id + src_idx))
|
|
|
|
|
|
@triton.jit
def fill_accepted_out_cache_loc(
    accept_index,
    out_cache_loc,
    accepted_out_cache_loc,
    size_upper: tl.constexpr,
):
    # One program per entry `slot` of accept_index. If accept_index[slot] is
    # not -1, out_cache_loc[accept_index[slot]] is written to
    # accepted_out_cache_loc[dst]. Here dst is the number of non -1 entries
    # before `slot`, so accepted entries are packed densely in order.
    # NOTE(review): the count only looks at the first size_upper entries, so
    # this presumably expects size_upper >= number of programs.
    slot = tl.program_id(axis=0)
    earlier = tl.arange(0, size_upper)
    # Masked-off lanes read -1 and therefore do not count as accepted.
    earlier_vals = tl.load(accept_index + earlier, mask=earlier < slot, other=-1)
    dst = tl.sum((earlier_vals != -1).to(tl.int64))
    src = tl.load(accept_index + slot)
    if src > -1:
        tl.store(accepted_out_cache_loc + dst, tl.load(out_cache_loc + src))
|
|
|
|
|
|
@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    # One program per request `req_id`. It copies
    # req_to_token[req_pool_indices[req_id], start_offset[req_id]:end_offset[req_id]]
    # into out_cache_loc, right after the spans of all earlier requests.
    # bs_upper must cover the batch size (it bounds the prefix-sum lanes).
    BLOCK_SIZE: tl.constexpr = 32
    req_id = tl.program_id(axis=0)
    lo = tl.load(start_offset + req_id)
    hi = tl.load(end_offset + req_id)
    row_ptr = req_to_token + tl.load(req_pool_indices + req_id) * pool_len

    # Output offset = total (end - start) length of requests 0 .. req_id - 1.
    lanes = tl.arange(0, bs_upper)
    is_earlier = lanes < req_id
    earlier_lo = tl.load(start_offset + lanes, mask=is_earlier, other=0)
    earlier_hi = tl.load(end_offset + lanes, mask=is_earlier, other=0)
    dst_base = out_cache_loc + tl.sum(earlier_hi - earlier_lo, axis=0)

    src_offs = lo + tl.arange(0, BLOCK_SIZE)
    dst_offs = tl.arange(0, BLOCK_SIZE)
    for _ in range(tl.cdiv(hi - lo, BLOCK_SIZE)):
        valid = src_offs < hi
        vals = tl.load(row_ptr + src_offs, mask=valid)
        tl.store(dst_base + dst_offs, vals, mask=valid)
        src_offs += BLOCK_SIZE
        dst_offs += BLOCK_SIZE
|