Clean up eagle code (#2756)
This commit is contained in:
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info.prepare_for_decode(batch)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
|
||||
def forward_draft_extend(self, batch: ScheduleBatch):
|
||||
self._swap_mem_pool(batch, self.model_runner)
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
|
||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
||||
if batch.forward_mode.is_decode():
|
||||
prev_spec_info = batch.spec_info
|
||||
self._swap_mem_pool(batch, self.model_runner)
|
||||
# Draft
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
for i in range(self.server_args.speculative_num_steps):
|
||||
self.forward_draft_decode(batch)
|
||||
batch.spec_info.clear_draft_cache(batch)
|
||||
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
|
||||
# Verify
|
||||
(
|
||||
next_draft_input,
|
||||
logits_output,
|
||||
verified_id,
|
||||
self.finish_extend_len,
|
||||
accept_length_cpu,
|
||||
model_worker_batch,
|
||||
) = self.verify(batch)
|
||||
next_draft_input.init(self.server_args)
|
||||
next_draft_input.load_server_args(self.server_args)
|
||||
batch.spec_info = next_draft_input
|
||||
# if it is None, means all requsets are finished
|
||||
if batch.spec_info.verified_id is not None:
|
||||
self.forward_extend_after_decode(batch)
|
||||
batch.spec_info = prev_spec_info
|
||||
return logits_output, verified_id, model_worker_batch, next_draft_input
|
||||
self.forward_draft_extend_after_decode(batch)
|
||||
return (
|
||||
logits_output,
|
||||
verified_id,
|
||||
model_worker_batch,
|
||||
sum(accept_length_cpu),
|
||||
)
|
||||
|
||||
else:
|
||||
spec_info = EAGLEDraftInput()
|
||||
spec_info.init(self.server_args)
|
||||
# Forward with the target model and get hidden states.
|
||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch.spec_info = spec_info
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
model_worker_batch.spec_info.verified_id = next_token_ids
|
||||
model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
|
||||
# Forward with the draft model.
|
||||
spec_info = EAGLEDraftInput()
|
||||
spec_info.load_server_args(self.server_args)
|
||||
spec_info.hidden_states = logits_output.hidden_states
|
||||
spec_info.verified_id = next_token_ids
|
||||
batch.spec_info = spec_info
|
||||
self.forward_draft_extend(batch)
|
||||
batch.spec_info = None
|
||||
return logits_output, next_token_ids, model_worker_batch, spec_info
|
||||
return logits_output, next_token_ids, model_worker_batch, 0
|
||||
|
||||
def verify(self, batch: ScheduleBatch):
|
||||
verify_input = batch.spec_info.prepare_for_verify(batch)
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
verify_input.prepare_for_verify(batch)
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = verify_input
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
return res + (model_worker_batch,)
|
||||
|
||||
def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
||||
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
||||
batch.token_to_kv_pool = runner.token_to_kv_pool
|
||||
batch.req_to_token_pool = runner.req_to_token_pool
|
||||
|
||||
def forward_extend_after_decode(self, batch: ScheduleBatch):
|
||||
self._swap_mem_pool(batch, self.model_runner)
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
self._set_mem_pool(batch, self.model_runner)
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
if batch.spec_info.has_finished:
|
||||
index = batch.spec_info.unfinished_index
|
||||
seq_lens = batch.seq_lens
|
||||
batch.seq_lens = batch.seq_lens[index]
|
||||
|
||||
batch.spec_info.prepare_extend_after_decode(batch)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
|
||||
batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
self.capture_for_decode(logits_output, forward_batch)
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
if batch.spec_info.has_finished:
|
||||
batch.seq_lens = seq_lens
|
||||
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||
|
||||
def capture_for_decode(self, logits_output, forward_batch):
|
||||
if isinstance(logits_output, LogitsProcessorOutput):
|
||||
logits = logits_output.next_token_logits
|
||||
def capture_for_decode(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
):
|
||||
sample_output = torch.softmax(
|
||||
logits, dim=-1
|
||||
) # TODO: Support more sampling method @kavioyu
|
||||
forward_batch.spec_info.capture_for_decode(
|
||||
sample_output, logits_output.hidden_states, forward_batch.forward_mode
|
||||
)
|
||||
logits_output.next_token_logits, dim=-1
|
||||
) # TODO(kavioyu): Support more sampling methods
|
||||
spec_info = forward_batch.spec_info
|
||||
spec_info.sample_output = sample_output
|
||||
spec_info.hidden_states = logits_output.hidden_states
|
||||
spec_info.prev_mode = forward_batch.forward_mode
|
||||
|
||||
# Don't support prefix share now.
|
||||
def finish_request(self, reqs: Union[Req, List[Req]]):
|
||||
|
||||
Reference in New Issue
Block a user