Misc fixes for eagle (flush_cache, CPU overhead) (#3014)
This commit is contained in:
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
|
||||
class EAGLEDraftInput(SpecInfo):
|
||||
def __init__(self):
|
||||
self.prev_mode = ForwardMode.DECODE
|
||||
self.sample_output = None
|
||||
|
||||
self.scores: torch.Tensor = None
|
||||
self.score_list: List[torch.Tensor] = []
|
||||
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
|
||||
self.cache_list: List[torch.Tenor] = []
|
||||
self.iter = 0
|
||||
|
||||
# shape: (b, hidden_size)
|
||||
self.hidden_states: torch.Tensor = None
|
||||
# shape: (b,)
|
||||
self.verified_id: torch.Tensor = None
|
||||
# shape: (b, vocab_size)
|
||||
self.sample_output: torch.Tensor = None
|
||||
|
||||
self.positions: torch.Tensor = None
|
||||
self.accept_length: torch.Tensor = None
|
||||
self.has_finished: bool = False
|
||||
self.unfinished_index: List[int] = None
|
||||
self.accept_length_cpu: List[int] = None
|
||||
|
||||
def load_server_args(self, server_args: ServerArgs):
|
||||
self.topk: int = server_args.speculative_eagle_topk
|
||||
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
|
||||
:pre_len
|
||||
] = req.prefix_indices
|
||||
|
||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
||||
out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
|
||||
@@ -295,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
|
||||
self.cache_list.append(batch.out_cache_loc)
|
||||
self.positions = (
|
||||
batch.seq_lens[:, None]
|
||||
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
|
||||
+ torch.full(
|
||||
[1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
|
||||
)
|
||||
).flatten()
|
||||
|
||||
bs = len(batch.seq_lens)
|
||||
@@ -312,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):
|
||||
|
||||
def prepare_extend_after_decode(self, batch: ScheduleBatch):
|
||||
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
|
||||
batch.extend_lens = (self.accept_length + 1).tolist()
|
||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
seq_lens_cpu = batch.seq_lens.tolist()
|
||||
|
||||
pt = 0
|
||||
seq_lens = batch.seq_lens.tolist()
|
||||
|
||||
i = 0
|
||||
|
||||
for req in batch.reqs:
|
||||
if req.finished():
|
||||
continue
|
||||
# assert seq_len - pre_len == req.extend_input_len
|
||||
input_len = self.accept_length[i] + 1
|
||||
seq_len = seq_lens[i]
|
||||
input_len = batch.extend_lens[i]
|
||||
seq_len = seq_lens_cpu[i]
|
||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||
seq_len - input_len : seq_len
|
||||
] = batch.out_cache_loc[pt : pt + input_len]
|
||||
pt += input_len
|
||||
i += 1
|
||||
assert pt == batch.out_cache_loc.shape[0]
|
||||
|
||||
self.positions = torch.empty_like(self.verified_id)
|
||||
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
|
||||
@@ -345,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
|
||||
triton.next_power_of_2(self.spec_steps + 1),
|
||||
)
|
||||
|
||||
batch.seq_lens_sum = sum(batch.seq_lens)
|
||||
batch.seq_lens_sum = sum(seq_lens_cpu)
|
||||
batch.input_ids = self.verified_id
|
||||
self.verified_id = new_verified_id
|
||||
|
||||
@@ -573,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
|
||||
finished_extend_len = {} # {rid:accept_length + 1}
|
||||
accept_index_cpu = accept_index.tolist()
|
||||
predict_cpu = predict.tolist()
|
||||
has_finished = False
|
||||
|
||||
# iterate every accepted token and check if req has finished after append the token
|
||||
# should be checked BEFORE free kv cache slots
|
||||
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
||||
@@ -586,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
|
||||
finished_extend_len[req.rid] = j + 1
|
||||
req.check_finished()
|
||||
if req.finished():
|
||||
draft_input.has_finished = True
|
||||
has_finished = True
|
||||
# set all tokens after finished token to -1 and break
|
||||
accept_index[i, j + 1 :] = -1
|
||||
break
|
||||
@@ -600,7 +608,6 @@ class EagleVerifyInput(SpecInfo):
|
||||
accept_index = accept_index[accept_index != -1]
|
||||
accept_length_cpu = accept_length.tolist()
|
||||
verified_id = predict[accept_index]
|
||||
verified_id_cpu = verified_id.tolist()
|
||||
|
||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||
evict_mask[accept_index] = False
|
||||
@@ -622,7 +629,13 @@ class EagleVerifyInput(SpecInfo):
|
||||
draft_input.verified_id = predict[new_accept_index]
|
||||
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
|
||||
draft_input.accept_length = accept_length[unfinished_index]
|
||||
draft_input.unfinished_index = unfinished_index
|
||||
draft_input.accept_length_cpu = [
|
||||
accept_length_cpu[i] for i in unfinished_index
|
||||
]
|
||||
if has_finished:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
||||
else:
|
||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||
|
||||
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
||||
return (
|
||||
|
||||
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
|
||||
from sglang.srt.utils import rank0_print
|
||||
|
||||
|
||||
class EAGLEWorker(TpModelWorker):
|
||||
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_draft_decode(self, batch: ScheduleBatch):
|
||||
batch.spec_info.prepare_for_decode(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
|
||||
def forward_draft_extend(self, batch: ScheduleBatch):
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.req_to_token_pool = runner.req_to_token_pool
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
seq_lens_backup = batch.seq_lens
|
||||
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
if batch.spec_info.has_finished:
|
||||
index = batch.spec_info.unfinished_index
|
||||
seq_lens = batch.seq_lens
|
||||
batch.seq_lens = batch.seq_lens[index]
|
||||
|
||||
batch.spec_info.prepare_extend_after_decode(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
|
||||
batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
if batch.spec_info.has_finished:
|
||||
batch.seq_lens = seq_lens
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user