Fix eagle3 cuda graph (#8163)
This commit is contained in:
@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner:
 self.hidden_states = torch.zeros(
     (
         self.max_num_token,
-        self.model_runner.model_config.hidden_size * 3,
+        (
+            self.model_runner.model_config.hf_config.target_hidden_size
+            * 3
+            if hasattr(
+                self.model_runner.model_config.hf_config,
+                "target_hidden_size",
+            )
+            else self.model_runner.model_config.hidden_size * 3
+        ),
     ),
     dtype=self.model_runner.dtype,
 )
Reference in New Issue
Block a user