Misc fixes for eagle (flush_cache, CPU overhead) (#3014)

This commit is contained in:
Lianmin Zheng
2025-01-20 20:25:13 -08:00
parent d2571dd5c7
commit 287d07a669
11 changed files with 133 additions and 96 deletions

View File

@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
class EAGLEDraftInput(SpecInfo):
def __init__(self):
self.prev_mode = ForwardMode.DECODE
self.sample_output = None
self.scores: torch.Tensor = None
self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
self.cache_list: List[torch.Tenor] = []
self.iter = 0
# shape: (b, hidden_size)
self.hidden_states: torch.Tensor = None
# shape: (b,)
self.verified_id: torch.Tensor = None
# shape: (b, vocab_size)
self.sample_output: torch.Tensor = None
self.positions: torch.Tensor = None
self.accept_length: torch.Tensor = None
self.has_finished: bool = False
self.unfinished_index: List[int] = None
self.accept_length_cpu: List[int] = None
def load_server_args(self, server_args: ServerArgs):
self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
:pre_len
] = req.prefix_indices
batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
)
@@ -295,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
self.cache_list.append(batch.out_cache_loc)
self.positions = (
batch.seq_lens[:, None]
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+ torch.full(
[1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
)
).flatten()
bs = len(batch.seq_lens)
@@ -312,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):
def prepare_extend_after_decode(self, batch: ScheduleBatch):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
batch.extend_lens = (self.accept_length + 1).tolist()
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()
pt = 0
seq_lens = batch.seq_lens.tolist()
i = 0
for req in batch.reqs:
if req.finished():
continue
# assert seq_len - pre_len == req.extend_input_len
input_len = self.accept_length[i] + 1
seq_len = seq_lens[i]
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
assert pt == batch.out_cache_loc.shape[0]
self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -345,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
triton.next_power_of_2(self.spec_steps + 1),
)
batch.seq_lens_sum = sum(batch.seq_lens)
batch.seq_lens_sum = sum(seq_lens_cpu)
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
@@ -573,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
# Iterate over every accepted token and check whether the request has finished after appending that token.
# This check must happen BEFORE the KV cache slots are freed.
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -586,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
draft_input.has_finished = True
has_finished = True
# set all tokens after finished token to -1 and break
accept_index[i, j + 1 :] = -1
break
@@ -600,7 +608,6 @@ class EagleVerifyInput(SpecInfo):
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
verified_id_cpu = verified_id.tolist()
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
@@ -622,7 +629,13 @@ class EagleVerifyInput(SpecInfo):
draft_input.verified_id = predict[new_accept_index]
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index]
draft_input.unfinished_index = unfinished_index
draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (

View File

@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
from sglang.srt.utils import rank0_print
class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):
def forward_draft_decode(self, batch: ScheduleBatch):
batch.spec_info.prepare_for_decode(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
batch.req_to_token_pool = runner.req_to_token_pool
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
if batch.spec_info.has_finished:
index = batch.spec_info.unfinished_index
seq_lens = batch.seq_lens
batch.seq_lens = batch.seq_lens[index]
batch.spec_info.prepare_extend_after_decode(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
batch.spec_info.hidden_states = logits_output.hidden_states
self.capture_for_decode(logits_output, forward_batch)
batch.forward_mode = ForwardMode.DECODE
if batch.spec_info.has_finished:
batch.seq_lens = seq_lens
self._set_mem_pool(batch, self.target_worker.model_runner)
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):