Open AI API hidden states (#6716)
This commit is contained in:
@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
A tuple of the final logit output of the target model, next tokens accepted,
|
||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||
"""
|
||||
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
@@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.out_cache_loc = out_cache_loc
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||
|
||||
# Get forward batch
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -547,11 +548,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.return_hidden_states = False
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=spec_info.seq_lens_cpu
|
||||
)
|
||||
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
|
||||
|
||||
if batch.has_grammar:
|
||||
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
|
||||
@@ -687,15 +690,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
hidden_states: Hidden states from the target model forward
|
||||
next_token_ids: Next token ids generated from the target forward.
|
||||
"""
|
||||
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
|
||||
batch.spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
verified_id=next_token_ids,
|
||||
)
|
||||
batch.return_hidden_states = False
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -718,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user