Fix eagle3 cuda graph (#8163)
This commit is contained in:
@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner:
 self.hidden_states = torch.zeros(
     (
         self.max_num_token,
-        self.model_runner.model_config.hidden_size * 3,
+        (
+            self.model_runner.model_config.hf_config.target_hidden_size
+            * 3
+            if hasattr(
+                self.model_runner.model_config.hf_config,
+                "target_hidden_size",
+            )
+            else self.model_runner.model_config.hidden_size * 3
+        ),
     ),
     dtype=self.model_runner.dtype,
 )
Reference in New Issue
Block a user