Open AI API hidden states (#6716)

This commit is contained in:
kyle-pena-kuzco
2025-06-10 17:37:29 -04:00
committed by GitHub
parent ce5ee3bdf0
commit b56de8f943
17 changed files with 606 additions and 44 deletions

View File

@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
@@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker):
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -547,11 +548,13 @@ class EAGLEWorker(TpModelWorker):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
batch.return_hidden_states = False
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -687,15 +690,18 @@ class EAGLEWorker(TpModelWorker):
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
)
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -718,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
batch,
self.speculative_num_steps,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)