Misc fixes for eagle (flush_cache, CPU overhead) (#3014)
This commit is contained in:
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
|
||||
from sglang.srt.utils import rank0_print
|
||||
|
||||
|
||||
class EAGLEWorker(TpModelWorker):
|
||||
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def forward_draft_decode(self, batch: ScheduleBatch):
|
||||
batch.spec_info.prepare_for_decode(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
|
||||
def forward_draft_extend(self, batch: ScheduleBatch):
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.req_to_token_pool = runner.req_to_token_pool
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
seq_lens_backup = batch.seq_lens
|
||||
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
if batch.spec_info.has_finished:
|
||||
index = batch.spec_info.unfinished_index
|
||||
seq_lens = batch.seq_lens
|
||||
batch.seq_lens = batch.seq_lens[index]
|
||||
|
||||
batch.spec_info.prepare_extend_after_decode(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
|
||||
batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
if batch.spec_info.has_finished:
|
||||
batch.seq_lens = seq_lens
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user