[Feature] Support EAGLE 3 (#4247)
This commit is contained in:
@@ -220,7 +220,19 @@ class CudaGraphRunner:
|
||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
||||
|
||||
# Speculative_inference
|
||||
if model_runner.spec_algorithm.is_eagle():
|
||||
if (
|
||||
model_runner.spec_algorithm.is_eagle3()
|
||||
and not model_runner.is_draft_worker
|
||||
):
|
||||
self.hidden_states = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
3 * self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.model_runner.model.set_eagle3_layers_to_capture()
|
||||
elif model_runner.spec_algorithm.is_eagle():
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_token, self.model_runner.model_config.hidden_size),
|
||||
dtype=self.model_runner.dtype,
|
||||
|
||||
@@ -210,6 +210,10 @@ class ModelRunner:
|
||||
self.cuda_graph_runner = None
|
||||
self.init_attention_backend()
|
||||
|
||||
# auxiliary hidden capture mode. TODO: expose this to server args?
|
||||
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
|
||||
self.model.set_eagle3_layers_to_capture()
|
||||
|
||||
def model_specific_adjustment(self):
|
||||
server_args = self.server_args
|
||||
|
||||
|
||||
Reference in New Issue
Block a user