refactor EAGLE 2 (#3269)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Authored by Yineng Zhang on 2025-02-03 20:52:30 +08:00; committed by GitHub
parent 3c8ac78dc1
commit 013021b6a1
9 changed files with 1271 additions and 687 deletions

View File

@@ -79,11 +79,13 @@ __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected
)
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
def build_tree_kernel(
parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
):
bs = seq_lens.numel()
device = parent_list.device
tree_mask = torch.full(
(torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
(seq_lens_sum * draft_token + draft_token * draft_token * bs,),
True,
device=device,
)
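The new seq_lens_sum argument lets the wrapper size the tree mask from a host-side integer that the scheduler already tracks, instead of calling torch.sum(seq_lens).item() and paying a GPU-to-CPU synchronization on every call. A minimal sketch of the difference, with invented helper names (mask_numel_old / mask_numel_new are not part of the diff):

import torch

def mask_numel_old(seq_lens: torch.Tensor, draft_token: int, bs: int) -> int:
    # .item() synchronizes the device on every call
    return int(torch.sum(seq_lens).item()) * draft_token + draft_token * draft_token * bs

def mask_numel_new(seq_lens_sum: int, draft_token: int, bs: int) -> int:
    # seq_lens_sum is a plain Python int maintained on the host (e.g. batch.seq_lens_sum)
    return seq_lens_sum * draft_token + draft_token * draft_token * bs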

View File

@@ -0,0 +1,213 @@
from __future__ import annotations
import bisect
import time
from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.model_executor.cuda_graph_runner import (
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
server_args = model_runner.server_args
assert self.disable_padding
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.num_tokens_per_bs = server_args.speculative_eagle_topk
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
if self.enable_torch_compile:
set_torch_compile_config()
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.out_cache_loc = torch.zeros(
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
self.hidden_states = torch.zeros(
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
return is_bs_supported
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, num_seqs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_tokens = num_seqs * self.num_tokens_per_bs
# Graph inputs
req_pool_indices = self.req_pool_indices[:num_seqs]
seq_lens = self.seq_lens[:num_seqs]
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
positions = self.positions[:num_tokens]
topk_p = self.topk_p[:num_seqs]
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=hidden_states,
)
# Forward batch
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=num_seqs,
input_ids=None,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
),
)
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
forward_batch
)
# Run and capture
def run_once():
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_forward(forward_batch)
forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
set_global_graph_memory_pool(graph.pool())
return graph, out
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
forward_batch.out_cache_loc
)
self.positions[:raw_num_token].copy_(forward_batch.positions)
self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch
)
# Replay
self.graphs[bs].replay()
return self.output_buffers[bs]
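For context, a hedged sketch of how this runner is meant to be driven; it mirrors the EAGLEWorker.draft path added later in this commit, and run_draft is an invented wrapper name used only for illustration:

def run_draft(worker, forward_batch):
    # Replay a pre-captured CUDA graph when this batch size was captured,
    # otherwise fall back to eager multi-step drafting.
    runner = worker.cuda_graph_runner
    if runner is not None and runner.can_run(forward_batch):
        return runner.replay(forward_batch)
    worker.draft_attn_backend.init_forward_metadata(forward_batch)
    return worker.draft_forward(forward_batch)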

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, List
import torch
@@ -9,13 +10,360 @@ import triton.language as tl
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
from sglang.srt.speculative.spec_info import SpecInfo
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.server_args import ServerArgs
@dataclasses.dataclass
class EagleDraftInput:
# The inputs for decode
# shape: (b, topk)
topk_p: torch.Tensor = None
topk_index: torch.Tensor = None
# shape: (b, hidden_size)
hidden_states: torch.Tensor = None
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
# Inputs for extend
# shape: (b,)
verified_id: torch.Tensor = None
accept_length: torch.Tensor = None
accept_length_cpu: List[int] = None
# Inputs for the attention backends
# shape: (b + 1,)
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
batch.out_cache_loc = out_cache_loc
pt = 0
for i, req in enumerate(batch.reqs):
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
:pre_len
] = req.prefix_indices
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
)
pt += req.extend_input_len
# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()
pt = 0
i = 0
for req in batch.reqs:
if req.finished():
continue
# assert seq_len - pre_len == req.extend_input_len
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
assert pt == batch.out_cache_loc.shape[0]
self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
self.accept_length.add_(1)
create_extend_spec_info[(self.accept_length.numel(),)](
self.verified_id,
batch.seq_lens,
self.accept_length,
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
self.positions,
new_verified_id,
triton.next_power_of_2(speculative_num_steps + 1),
)
batch.seq_lens_sum = sum(seq_lens_cpu)
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, None
def filter_batch(self, new_indices: torch.Tensor):
self.topk_p = self.topk_p[: len(new_indices)]
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
def merge_batch(self, spec_info: EagleDraftInput):
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
self.topk_p = spec_info.topk_p
self.topk_index = spec_info.topk_index
return
if spec_info.hidden_states is None:
return
self.hidden_states = torch.cat(
[self.hidden_states, spec_info.hidden_states], axis=0
)
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
@dataclasses.dataclass
class EagleVerifyInput:
draft_token: torch.Tensor
custom_mask: torch.Tensor
positions: torch.Tensor
retrive_index: torch.Tensor
retrive_cum_len: torch.Tensor
draft_token_num: int
capture_hidden_mode: CaptureHiddenMode
@classmethod
def create(
cls,
verified_id: torch.Tensor,
score_list: List[torch.Tensor],
token_list: List[torch.Tensor],
parents_list: List[torch.Tensor],
seq_lens: torch.Tensor,
seq_lens_sum: int,
topk: int,
spec_steps: int,
num_verify_token: int,
):
score_list = torch.cat(score_list, dim=1).flatten(
1
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
ss_token_list = torch.cat(
token_list, dim=1
) # b, (self.topk + (num_steps-1) * self.topk)
top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
parent_list = torch.cat(parents_list[:-1], dim=1)
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
parent_list,
top_scores_index,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_token,
)
return cls(
draft_tokens.flatten(),
tree_mask,
position,
retrive_index,
retrive_cum_len,
num_verify_token,
CaptureHiddenMode.FULL,
)
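To make the selection above concrete, here is a toy CPU-only walkthrough of how create picks the draft tokens before build_tree_kernel derives the mask, positions, and retrieve indices. All values are invented; with topk = 2, the flattened scores cover the first drafting step (topk nodes) plus one later step (topk * topk nodes):

import torch

num_verify_token = 4
scores = torch.tensor([[0.9, 0.1, 0.5, 0.3, 0.05, 0.05]])   # (b, n * topk), flattened
tokens = torch.tensor([[101, 102, 201, 202, 203, 204]])     # candidate token ids
verified_id = torch.tensor([7])

idx = torch.topk(scores, num_verify_token - 1, dim=-1).indices
idx = torch.sort(idx).values                                  # [[0, 2, 3]]: keep tree order
draft = torch.gather(tokens, 1, idx)                          # [[101, 201, 202]]
draft = torch.cat((verified_id.unsqueeze(1), draft), dim=1)   # [[7, 101, 201, 202]]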
def prepare_for_verify(self, batch: ScheduleBatch):
batch.input_ids = self.draft_token
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
bs = batch.batch_size()
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
req_to_token: torch.Tensor,
):
batch_size = len(req_pool_indices)
qo_indptr = torch.arange(
0,
(1 + batch_size) * self.draft_token_num,
step=self.draft_token_num,
dtype=torch.int32,
device="cuda",
)
cum_kv_seq_len = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda"
)
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
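A toy example of the indptrs this method builds (illustrative values only; batch_size = 2, draft_token_num = 4, paged KV lengths [5, 7] before the drafts are appended):

import torch

draft_token_num = 4
paged_kernel_lens = torch.tensor([5, 7], dtype=torch.int32)

# Every request contributes exactly draft_token_num query tokens
qo_indptr = torch.arange(0, 3 * draft_token_num, step=draft_token_num, dtype=torch.int32)
# -> tensor([0, 4, 8], dtype=torch.int32)

kv_lens = paged_kernel_lens + draft_token_num         # tensor([ 9, 11])
cum_kv_seq_len = torch.zeros(3, dtype=torch.int32)
cum_kv_seq_len[1:] = torch.cumsum(kv_lens, dim=0)     # tensor([ 0,  9, 20])
# kv_indices is then filled per request from req_to_token by
# create_flashinfer_kv_indices_triton.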
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
)
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
dim=-1,
)
target_predict = predict[self.retrive_index]
candidates = draft_token[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
bs = self.retrive_cum_len.numel() - 1
max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
new_accept_index = []
unfinished_index = []
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
# iterate over every accepted token and check whether the req has finished after appending the token
# this must be checked BEFORE freeing the kv cache slots
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
new_accept_index_ = []
for j, idx in enumerate(accept_index_row):
if idx == -1:
break
id = predict_cpu[idx]
# if not found_finished:
req.output_ids.append(id)
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
has_finished = True
# set all tokens after finished token to -1 and break
accept_index[i, j + 1 :] = -1
break
else:
new_accept_index_.append(idx)
if not req.finished():
new_accept_index.extend(new_accept_index_)
unfinished_index.append(i)
req.spec_verify_ct += 1
accept_length = (accept_index != -1).sum(dim=1) - 1
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
batch.token_to_kv_pool.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
draft_input = EagleDraftInput()
if len(new_accept_index) > 0:
new_accept_index = torch.tensor(new_accept_index, device="cuda")
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
draft_input.verified_id = predict[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index]
draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (
draft_input,
logits_output,
verified_id,
finished_extend_len,
accept_length_cpu,
)
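The core of the acceptance test above is the cumprod trick: a drafted token is accepted only if it equals the target model's prediction at the preceding node and every earlier node on its chain also matched. A small CPU-only example with invented values (the boolean mask is cast explicitly for clarity); candidates holds the verified token followed by the drafted tokens along one chain, and target_predict holds the target model's argmax at each of those positions:

import torch

candidates     = torch.tensor([[10, 11, 12, 99, 13]])   # diverges at position 3
target_predict = torch.tensor([[11, 12, 13, 14, 15]])

match = (candidates[:, 1:] == target_predict[:, :-1]).long()   # [[1, 1, 0, 0]]
accepted = torch.cumprod(match, dim=1).sum(dim=1)              # tensor([2])
# cumprod zeroes everything after the first mismatch, so the sum is the
# length of the accepted draft prefix (2 tokens accepted here).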
@triton.jit
@@ -136,21 +484,57 @@ def assign_req_to_token_pool(
load_offset += BLOCK_SIZE
@triton.jit
def assign_draft_cache_locs(
req_pool_indices,
req_to_token,
seq_lens,
out_cache_loc,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(seq_lens + pid)
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
for i in range(num_loop):
save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = save_offset < kv_end
data = tl.load(out_cache_ptr + load_offset, mask=mask)
tl.store(token_pool + save_offset, data, mask=mask)
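A reference (pure PyTorch, non-Triton) sketch of what this kernel does for each request: copy the topk * speculative_num_steps freshly allocated cache slots into that request's req_to_token row, starting right after its current length. assign_draft_cache_locs_ref is an invented name for illustration only:

import torch

def assign_draft_cache_locs_ref(req_pool_indices, req_to_token, seq_lens,
                                out_cache_loc, topk, speculative_num_steps):
    num_new = topk * speculative_num_steps
    for pid in range(req_pool_indices.numel()):
        row = req_to_token[req_pool_indices[pid]]
        start = int(seq_lens[pid])
        row[start : start + num_new] = out_cache_loc[pid * num_new : (pid + 1) * num_new]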
@triton.jit
def generate_draft_decode_kv_indices(
req_pool_indices,
req_to_token,
paged_kernel_lens,
kv_indices,
iters: tl.constexpr,
kv_indptr,
positions,
num_seqs: tl.constexpr,
topk: tl.constexpr,
pool_len: tl.constexpr,
kv_indices_stride: tl.constexpr,
kv_indptr_stride: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
num_tokens_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
bid = tl.program_id(axis=0)
topk_id = tl.program_id(axis=1)
iters = tl.program_id(axis=0)
bid = tl.program_id(axis=1)
topk_id = tl.program_id(axis=2)
kv_indices += kv_indices_stride * iters
kv_indptr += kv_indptr_stride * iters
iters += 1
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
@@ -176,473 +560,73 @@ def generate_draft_decode_kv_indices(
)
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
# Update kv_indptr
bs_offset = tl.arange(0, num_tokens_upper)
class EAGLEDraftInput(SpecInfo):
def __init__(self):
self.prev_mode = ForwardMode.DECODE
zid = bid * topk + topk_id
if zid == 0:
zid = num_seqs * topk
positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
base = tl.sum(positions)
tl.store(kv_indptr + zid, base + zid * iters)
self.scores: torch.Tensor = None
self.score_list: List[torch.Tensor] = []
self.token_list: List[torch.Tensor] = []
self.origin_score_list: List[torch.Tensor] = [] # used for sampling
self.parents_list: List[torch.Tensor] = []
self.cache_list: List[torch.Tensor] = []
self.iter = 0
# shape: (b, hidden_size)
self.hidden_states: torch.Tensor = None
# shape: (b,)
self.verified_id: torch.Tensor = None
# shape: (b, vocab_size)
self.sample_output: torch.Tensor = None
@torch.compile
def select_top_k_tokens(
i: int,
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: torch.Tensor,
scores: torch.Tensor,
topk: int,
):
if i == 0:
# The first step after extend
input_ids = topk_index.flatten()
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
scores = topk_p # shape: (b, topk)
self.positions: torch.Tensor = None
self.accept_length: torch.Tensor = None
self.accept_length_cpu: List[int] = None
def load_server_args(self, server_args: ServerArgs):
self.topk: int = server_args.speculative_eagle_topk
self.num_verify_token: int = server_args.speculative_num_draft_tokens
self.spec_steps = server_args.speculative_num_steps
def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
batch.out_cache_loc = out_cache_loc
pt = 0
for i, req in enumerate(batch.reqs):
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
:pre_len
] = req.prefix_indices
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
)
pt += req.extend_input_len
# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def filter_batch(
self,
new_indices: torch.Tensor,
):
self.sample_output = self.sample_output[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
def prepare_for_decode(self, batch: ScheduleBatch):
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
top = torch.topk(prob, self.topk, dim=-1)
topk_index, topk_p = (
top.indices,
top.values,
) # shape: (b * top_k, top_k) or (b, top_k)
if self.prev_mode.is_decode():
scores = torch.mul(
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs = torch.topk(
scores.flatten(start_dim=1), self.topk, dim=-1
) # (b, topk)
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
).repeat_interleave(self.topk)
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
selected_input_index, :
]
topk_index = topk_index.reshape(-1, self.topk**2)
batch.input_ids = torch.gather(
topk_index, index=topk_cs_index, dim=1
).flatten()
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
self.scores = topk_cs_p
self.score_list.append(scores) # (b, topk, topk)
self.token_list.append(topk_index) # (b, topk * topk)
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
self.parents_list.append(
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
) # shape: (b, topk)
else:
# ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
batch.spec_info.hidden_states = (
batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
)
batch.input_ids = topk_index.flatten()
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
self.scores = topk_p # shape: (b, topk)
self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk)
self.token_list.append(topk_index) # shape: (b, topk)
self.origin_score_list.append(topk_p)
self.parents_list.append(
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
.unsqueeze(0)
.repeat(self.scores.shape[0], 1)
) # shape: (b, topk + 1)
self.cache_list.append(batch.out_cache_loc)
self.positions = (
batch.seq_lens[:, None]
+ torch.full(
[1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
)
).flatten()
bs = len(batch.seq_lens)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens + self.topk * self.iter,
batch.seq_lens + self.topk * (self.iter + 1),
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
self.iter += 1
def prepare_extend_after_decode(self, batch: ScheduleBatch):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()
pt = 0
i = 0
for req in batch.reqs:
if req.finished():
continue
# assert seq_len - pre_len == req.extend_input_len
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
assert pt == batch.out_cache_loc.shape[0]
self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
self.accept_length.add_(1)
create_extend_spec_info[(self.accept_length.numel(),)](
self.verified_id,
batch.seq_lens,
self.accept_length,
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
self.positions,
new_verified_id,
triton.next_power_of_2(self.spec_steps + 1),
tree_info = (
topk_p.unsqueeze(1), # shape: (b, 1, topk)
topk_index, # shape: (b, topk)
torch.arange(-1, topk, dtype=torch.long, device="cuda")
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)
batch.seq_lens_sum = sum(seq_lens_cpu)
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
else:
# The later decode steps
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
def prepare_for_verify(self, batch: ScheduleBatch):
score_list = torch.cat(self.score_list, dim=1).flatten(
1
) # b, n, topk; n= 1+(self.iter-1)*self.topk
ss_token_list = torch.cat(
self.token_list, dim=1
) # b, (self.topk+(self.iter-1)*self.topk)
origin_token_list = torch.cat(self.origin_score_list, dim=1)
top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk)
scores = topk_cs_p # shape: (b, topk)
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
parent_list = torch.cat(self.parents_list[:-1], dim=1)
topk_index = topk_index.reshape(-1, topk**2)
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
parent_list,
top_scores_index,
batch.seq_lens,
self.topk,
self.iter - 1,
self.num_verify_token,
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device="cuda"
).repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index, :]
tree_info = (
expand_scores, # shape: (b, topk, topk)
topk_index, # shape: (b, topk * topk)
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
)
return EagleVerifyInput(
draft_tokens.flatten(),
scores.flatten(),
tree_mask,
position,
retrive_index,
retrive_cum_len,
self.num_verify_token,
)
def generate_attn_arg_decode(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
req_to_token: torch.Tensor,
):
seq_num = req_pool_indices.numel()
bs = self.topk * req_pool_indices.numel()
seq_len = self.positions.reshape(-1).contiguous()
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
total_len = torch.sum(paged_kernel_lens).item()
kv_indices = torch.empty(
(total_len * self.topk + seq_num * self.iter * self.topk,),
dtype=torch.int32,
device="cuda",
)
generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
req_pool_indices,
req_to_token,
paged_kernel_lens,
kv_indices,
self.iter,
self.topk,
req_to_token.shape[1],
triton.next_power_of_2(seq_num),
triton.next_power_of_2(self.spec_steps),
)
return bs, kv_indices, cum_kv_seq_len
def clear_draft_cache(self, batch):
draft_cache = torch.cat(self.cache_list, dim=0)
batch.token_to_kv_pool.free(draft_cache)
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, None
def merge_batch(self, spec_info: EAGLEDraftInput):
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
self.sample_output = spec_info.sample_output
self.prev_mode = spec_info.prev_mode
return
if spec_info.hidden_states is None:
return
self.hidden_states = torch.cat(
[self.hidden_states, spec_info.hidden_states], axis=0
)
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
return input_ids, hidden_states, scores, tree_info
class EagleVerifyInput(SpecInfo):
def __init__(
self,
draft_token: torch.Tensor,
draft_score: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
retrive_cum_len: torch.Tensor,
draft_token_num: int,
):
self.draft_token = draft_token
self.draft_score = draft_score
self.custom_mask = tree_mask
self.positions = positions
self.retrive_index = retrive_index
self.retrive_cum_len = retrive_cum_len
self.draft_token_num = draft_token_num
def prepare_for_verify(self, batch: ScheduleBatch):
batch.input_ids = self.draft_token
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
bs = batch.seq_lens.numel()
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
req_to_token: torch.Tensor,
):
batch_size = len(req_pool_indices)
qo_indptr = torch.arange(
0,
(1 + batch_size) * self.draft_token_num,
step=self.draft_token_num,
dtype=torch.int32,
device="cuda",
)
cum_kv_seq_len = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda"
)
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1
)
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")],
dim=-1,
)
target_predict = predict[self.retrive_index]
candidates = draft_token[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
bs = self.retrive_cum_len.numel() - 1
max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.long, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
draft_input = EAGLEDraftInput()
new_accept_index = []
unfinished_index = []
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
# iterate over every accepted token and check whether the req has finished after appending the token
# this must be checked BEFORE freeing the kv cache slots
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
new_accept_index_ = []
for j, idx in enumerate(accept_index_row):
if idx == -1:
break
id = predict_cpu[idx]
# if not found_finished:
req.output_ids.append(id)
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
has_finished = True
# set all tokens after finished token to -1 and break
accept_index[i, j + 1 :] = -1
break
else:
new_accept_index_.append(idx)
if not req.finished():
new_accept_index.extend(new_accept_index_)
unfinished_index.append(i)
req.spec_verify_ct += 1
accept_length = (accept_index != -1).sum(dim=1) - 1
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
batch.token_to_kv_pool.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
if len(new_accept_index) > 0:
new_accept_index = torch.tensor(new_accept_index, device="cuda")
draft_input.verified_id = predict[new_accept_index]
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index]
draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (
draft_input,
logits_output,
verified_id,
finished_extend_len,
accept_length_cpu,
)
def fast_topk(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
max_value, max_index = torch.max(values, dim=dim)
return max_value.unsqueeze(1), max_index.unsqueeze(1)
else:
# Use topk for efficiency with larger k values
return torch.topk(values, topk, dim=dim)
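fast_topk keeps torch.topk's (values, indices) return layout but routes k == 1 through torch.max, which avoids the general top-k path for the common single-candidate case. A quick illustrative check (2-D input with dim=-1, the way it is used in this commit; this snippet is not part of the diff and assumes fast_topk is in scope):

import torch

x = torch.tensor([[0.1, 0.7, 0.2]])
v1, i1 = fast_topk(x, 1, dim=-1)    # via torch.max -> values [[0.7]], indices [[1]]
v2, i2 = torch.topk(x, 1, dim=-1)
assert torch.equal(v1, v2) and torch.equal(i1, i2)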

View File

@@ -1,3 +1,5 @@
import logging
import time
from typing import List, Optional, Union
import torch
@@ -12,8 +14,18 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
from sglang.srt.utils import rank0_print
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
assign_draft_cache_locs,
fast_topk,
select_top_k_tokens,
)
logger = logging.getLogger(__name__)
class EAGLEWorker(TpModelWorker):
@@ -40,41 +52,47 @@ class EAGLEWorker(TpModelWorker):
is_draft_worker=True,
)
self.target_worker = target_worker
self.server_args = server_args
self.finish_extend_len = []
# Parse arguments
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.server_args = server_args
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.model_runner.init_cuda_graphs()
def forward_draft_decode(self, batch: ScheduleBatch):
batch.spec_info.prepare_for_decode(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
# Create multi-step attn backends and cuda graph runners
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.model_runner,
self.topk,
self.speculative_num_steps,
)
self.model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
if self.server_args.disable_cuda_graph:
return
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
if batch.forward_mode.is_decode():
# Draft
self._set_mem_pool(batch, self.model_runner)
for i in range(self.server_args.speculative_num_steps):
self.forward_draft_decode(batch)
batch.spec_info.clear_draft_cache(batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
spec_info: EagleVerifyInput = self.draft(batch)
# Verify
(
@@ -84,8 +102,7 @@ class EAGLEWorker(TpModelWorker):
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch)
next_draft_input.load_server_args(self.server_args)
) = self.verify(batch, spec_info)
batch.spec_info = next_draft_input
# if it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
@@ -107,29 +124,145 @@ class EAGLEWorker(TpModelWorker):
)
# Forward with the draft model.
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.server_args)
spec_info.hidden_states = logits_output.hidden_states
spec_info.verified_id = next_token_ids
batch.spec_info = spec_info
batch.spec_info = EagleDraftInput(
hidden_states=logits_output.hidden_states,
verified_id=next_token_ids,
)
self.forward_draft_extend(batch)
return logits_output, next_token_ids, model_worker_batch, 0
def verify(self, batch: ScheduleBatch):
verify_input = batch.spec_info.prepare_for_verify(batch)
verify_input.prepare_for_verify(batch)
def draft(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
# Allocate cache locations
out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps
)
assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
)
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch
)
if can_cuda_graph:
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
forward_batch
)
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
ret = EagleVerifyInput.create(
spec_info.verified_id,
score_list,
token_list,
parents_list,
batch.seq_lens,
batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
)
# Free cache locations
batch.token_to_kv_pool.free(out_cache_loc)
self._set_mem_pool(batch, self.target_worker.model_runner)
return ret
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info = forward_batch.spec_info
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
spec_info.topk_index,
spec_info.hidden_states,
)
# Return values
score_list: List[torch.Tensor] = []
token_list: List[torch.Tensor] = []
parents_list: List[torch.Tensor] = []
# Forward multiple steps
scores = None
for i in range(self.speculative_num_steps):
input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
i, topk_p, topk_index, hidden_states, scores, self.topk
)
score_list.append(tree_info[0])
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
# Set inputs
forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[
forward_batch.batch_size
* self.topk
* i : forward_batch.batch_size
* self.topk
* (i + 1)
]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
hidden_states = logits_output.hidden_states
return score_list, token_list, parents_list
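One detail worth spelling out: draft() allocates all num_seqs * topk * speculative_num_steps cache slots up front, and draft_forward consumes them one contiguous slice per step. A toy illustration with invented sizes (bs = 2, topk = 2, speculative_num_steps = 3):

import torch

bs, topk, num_steps = 2, 2, 3
out_cache_loc = torch.arange(bs * topk * num_steps)   # stand-in slot ids 0..11
for i in range(num_steps):
    step_slots = out_cache_loc[bs * topk * i : bs * topk * (i + 1)]
    print(i, step_slots.tolist())   # step 0 -> [0..3], step 1 -> [4..7], step 2 -> [8..11]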
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = verify_input
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
verify_input.hidden_states = logits_output.hidden_states
res = verify_input.verify(batch, logits_output)
spec_info.hidden_states = logits_output.hidden_states
res = spec_info.verify(batch, logits_output)
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
@@ -139,7 +272,7 @@ class EAGLEWorker(TpModelWorker):
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch)
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -155,13 +288,10 @@ class EAGLEWorker(TpModelWorker):
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):
sample_output = torch.softmax(
logits_output.next_token_logits, dim=-1
) # TODO(kavioyu): Support more sampling methods
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
spec_info = forward_batch.spec_info
spec_info.sample_output = sample_output
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
spec_info.hidden_states = logits_output.hidden_states
spec_info.prev_mode = forward_batch.forward_mode
# Prefix sharing is not supported yet.
def finish_request(self, reqs: Union[Req, List[Req]]):