Fix dtype for idle input in spec decoding (#7456)
This commit is contained in:
@@ -89,14 +89,13 @@ class EagleDraftInput:
|
||||
cls,
|
||||
device: torch.device,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
topk: int,
|
||||
capture_hidden_mode: CaptureHiddenMode,
|
||||
):
|
||||
return cls(
|
||||
verified_id=None,
|
||||
hidden_states=torch.empty(
|
||||
(0, hidden_size), device=device, dtype=torch.float32
|
||||
),
|
||||
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
|
||||
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
|
||||
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
|
||||
capture_hidden_mode=capture_hidden_mode,
|
||||
@@ -334,6 +333,7 @@ class EagleVerifyInput:
|
||||
draft_input=EagleDraftInput.create_idle_input(
|
||||
device=batch.device,
|
||||
hidden_size=batch.model_config.hidden_size,
|
||||
dtype=batch.model_config.dtype,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
),
|
||||
|
||||
@@ -498,6 +498,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=self.model_config.hidden_size,
|
||||
dtype=self.model_config.dtype,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
@@ -838,6 +839,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info = EagleDraftInput.create_idle_input(
|
||||
device=self.device,
|
||||
hidden_size=self.model_config.hidden_size,
|
||||
dtype=self.model_config.dtype,
|
||||
topk=self.topk,
|
||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user