[Feature] Support EAGLE 3 (#4247)

This commit is contained in:
James Liu
2025-03-18 10:35:23 -04:00
committed by GitHub
parent 8baf9a0c18
commit 9e0186f352
11 changed files with 385 additions and 22 deletions

View File

@@ -220,7 +220,19 @@ class CudaGraphRunner:
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
# Speculative inference
if model_runner.spec_algorithm.is_eagle():
if (
model_runner.spec_algorithm.is_eagle3()
and not model_runner.is_draft_worker
):
self.hidden_states = torch.zeros(
(
self.max_num_token,
3 * self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.model_runner.model.set_eagle3_layers_to_capture()
elif model_runner.spec_algorithm.is_eagle():
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,

View File

@@ -210,6 +210,10 @@ class ModelRunner:
self.cuda_graph_runner = None
self.init_attention_backend()
# Auxiliary hidden-state capture mode for EAGLE-3 (non-draft worker only). TODO: expose this to server args?
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
self.model.set_eagle3_layers_to_capture()
def model_specific_adjustment(self):
server_args = self.server_args