Clean up eagle code (#2756)

This commit is contained in:
Lianmin Zheng
2025-01-06 14:54:18 -08:00
committed by GitHub
parent 2855caa481
commit b8574f6953
7 changed files with 138 additions and 128 deletions

View File

@@ -9,12 +9,11 @@ import triton.language as tl
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
from sglang.srt.speculative.spec_info import SpecInfo
if TYPE_CHECKING:
from python.sglang.srt.layers.sampler import SampleOutput
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.server_args import ServerArgs
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
class EAGLEDraftInput(SpecInfo):
hidden_states: torch.Tensor = None
verified_id: torch.Tensor = None
positions: torch.Tensor = None
accept_length: torch.Tensor = None
has_finished: bool = False
unfinished_index: List[int] = None
def init(self, server_args: ServerArgs):
def __init__(self):
self.prev_mode = ForwardMode.DECODE
self.sample_output = None
self.topk: int = server_args.speculative_eagle_topk
self.num_verify_token: int = server_args.speculative_num_draft_tokens
self.spec_steps = server_args.speculative_num_steps
self.scores: torch.Tensor = None
self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
self.parents_list: List[torch.Tensor] = []
self.cache_list: List[torch.Tenor] = []
self.iter = 0
self.root_token: int = None
assert self.topk <= 10, "topk should <= 10"
self.hidden_states: torch.Tensor = None
self.verified_id: torch.Tensor = None
self.positions: torch.Tensor = None
self.accept_length: torch.Tensor = None
self.has_finished: bool = False
self.unfinished_index: List[int] = None
def prepare_for_extend(self, batch: ForwardBatch):
def load_server_args(self, server_args: ServerArgs):
self.topk: int = server_args.speculative_eagle_topk
self.num_verify_token: int = server_args.speculative_num_draft_tokens
self.spec_steps = server_args.speculative_num_steps
def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
batch.out_cache_loc = out_cache_loc
@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
pt += req.extend_input_len
seq_lens = [0] + batch.extend_lens
input_ids = batch.input_ids.tolist()
verified_id = batch.spec_info.verified_id.tolist()
model_input_ids = []
for i in range(len(seq_lens) - 1):
model_input_ids.extend(
input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
)
batch.input_ids = torch.tensor(
model_input_ids, dtype=torch.int32, device="cuda"
)
def capture_for_decode(
self,
sample_output: SampleOutput,
hidden_states: torch.Tensor,
prev_mode: ForwardMode,
):
self.sample_output = sample_output
self.prev_mode = prev_mode
self.hidden_states = hidden_states
# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def prepare_for_decode(self, batch: ScheduleBatch):
prob = self.sample_output # b * (1/topk), vocab
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
top = torch.topk(prob, self.topk, dim=-1)
topk_index, topk_p = top.indices, top.values # b * (1/topk), topk
if self.prev_mode == ForwardMode.DECODE:
topk_index, topk_p = (
top.indices,
top.values,
) # shape: (b * top_k, top_k) or (b, top_k)
if self.prev_mode.is_decode():
scores = torch.mul(
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
) # (b, topk) mul (b * topk ,topk) -> b, topk, topk
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs = torch.topk(
scores.flatten(start_dim=1), self.topk, dim=-1
) # (b, topk)
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
self.scores = topk_cs_p
selected_input_index = topk_cs_index.flatten() // self.topk # b* topk
selected_input_index = (
topk_cs_index.flatten() // self.topk
) # shape: (b * topk)
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
selected_input_index, :
]
topk_index = topk_index.reshape(-1, self.topk**2)
batch.input_ids = torch.gather(
topk_index, index=topk_cs_index, dim=1
).flatten()
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
self.score_list.append(scores) # b, topk, topk
self.token_list.append(topk_index) # b, topk*topk
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
self.scores = topk_cs_p
self.score_list.append(scores) # (b, topk, topk)
self.token_list.append(topk_index) # (b, topk * topk)
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
self.parents_list.append(
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
) # b, topk
elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
self.scores = topk_p # b, top_k
self.score_list.append(topk_p.unsqueeze(1))
self.token_list.append(topk_index)
self.origin_score_list.append(topk_p)
) # shape: (b, topk)
else:
# ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
batch.spec_info.hidden_states = (
batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
)
batch.input_ids = topk_index.flatten()
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
self.scores = topk_p # shape: (b, topk)
self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk)
self.token_list.append(topk_index) # shape: (b, topk)
self.origin_score_list.append(topk_p)
self.parents_list.append(
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
.unsqueeze(0)
.repeat(self.scores.shape[0], 1)
) # b, topk+1
) # shape: (b, topk + 1)
self.cache_list.append(batch.out_cache_loc)
self.positions = (
batch.seq_lens[:, None]
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
).flatten()
bs = batch.seq_lens.numel()
bs = len(batch.seq_lens)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
)
return bs, kv_indices, cum_kv_seq_len
def clear(self):
self.iter = 0
self.score_list.clear()
self.positions = None
def clear_draft_cache(self, batch):
draft_cache = torch.cat(self.cache_list, dim=0)
batch.token_to_kv_pool.free(draft_cache)
@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
[self.hidden_states, spec_info.hidden_states], axis=0
)
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
# self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
)
accept_index = accept_index[accept_index != -1]
# extract_index = extract_index[extract_index != 0]
draft_input = EAGLEDraftInput()
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
# retracted_reqs, new_token_ratio = batch.retract_decode()
low = 0
draft_input = EAGLEDraftInput()
for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
req.check_finished()
@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
draft_input.unfinished_index = unfinished_index
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return draft_input, logits_output, verified_id, finished_extend_len
return (
draft_input,
logits_output,
verified_id,
finished_extend_len,
accept_length_cpu,
)

View File

@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_for_decode(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
def forward_draft_extend(self, batch: ScheduleBatch):
self._swap_mem_pool(batch, self.model_runner)
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._swap_mem_pool(batch, self.target_worker.model_runner)
self._set_mem_pool(batch, self.target_worker.model_runner)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
if batch.forward_mode.is_decode():
prev_spec_info = batch.spec_info
self._swap_mem_pool(batch, self.model_runner)
# Draft
self._set_mem_pool(batch, self.model_runner)
for i in range(self.server_args.speculative_num_steps):
self.forward_draft_decode(batch)
batch.spec_info.clear_draft_cache(batch)
self._swap_mem_pool(batch, self.target_worker.model_runner)
self._set_mem_pool(batch, self.target_worker.model_runner)
# Verify
(
next_draft_input,
logits_output,
verified_id,
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch)
next_draft_input.init(self.server_args)
next_draft_input.load_server_args(self.server_args)
batch.spec_info = next_draft_input
# if it is None, means all requsets are finished
if batch.spec_info.verified_id is not None:
self.forward_extend_after_decode(batch)
batch.spec_info = prev_spec_info
return logits_output, verified_id, model_worker_batch, next_draft_input
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verified_id,
model_worker_batch,
sum(accept_length_cpu),
)
else:
spec_info = EAGLEDraftInput()
spec_info.init(self.server_args)
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_info = spec_info
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
model_worker_batch.spec_info.verified_id = next_token_ids
model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
# Forward with the draft model.
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.server_args)
spec_info.hidden_states = logits_output.hidden_states
spec_info.verified_id = next_token_ids
batch.spec_info = spec_info
self.forward_draft_extend(batch)
batch.spec_info = None
return logits_output, next_token_ids, model_worker_batch, spec_info
return logits_output, next_token_ids, model_worker_batch, 0
def verify(self, batch: ScheduleBatch):
verify_input = batch.spec_info.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
verify_input.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = verify_input
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch = batch.get_model_worker_batch()
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
def forward_extend_after_decode(self, batch: ScheduleBatch):
self._swap_mem_pool(batch, self.model_runner)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
if batch.spec_info.has_finished:
index = batch.spec_info.unfinished_index
seq_lens = batch.seq_lens
batch.seq_lens = batch.seq_lens[index]
batch.spec_info.prepare_extend_after_decode(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
batch.spec_info.hidden_states = logits_output.hidden_states
self.capture_for_decode(logits_output, forward_batch)
batch.forward_mode = ForwardMode.DECODE
if batch.spec_info.has_finished:
batch.seq_lens = seq_lens
self._swap_mem_pool(batch, self.target_worker.model_runner)
self._set_mem_pool(batch, self.target_worker.model_runner)
def capture_for_decode(self, logits_output, forward_batch):
if isinstance(logits_output, LogitsProcessorOutput):
logits = logits_output.next_token_logits
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):
sample_output = torch.softmax(
logits, dim=-1
) # TODO: Support more sampling method @kavioyu
forward_batch.spec_info.capture_for_decode(
sample_output, logits_output.hidden_states, forward_batch.forward_mode
)
logits_output.next_token_logits, dim=-1
) # TODO(kavioyu): Support more sampling methods
spec_info = forward_batch.spec_info
spec_info.sample_output = sample_output
spec_info.hidden_states = logits_output.hidden_states
spec_info.prev_mode = forward_batch.forward_mode
# Don't support prefix share now.
def finish_request(self, reqs: Union[Req, List[Req]]):