Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -408,13 +408,13 @@ class CudaGraphRunner:
|
||||
)
|
||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||
if (
|
||||
forward_batch.sampling_info.return_hidden_states
|
||||
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
|
||||
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||
):
|
||||
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
self.capture()
|
||||
elif (
|
||||
not forward_batch.sampling_info.return_hidden_states
|
||||
forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||
and self.capture_hidden_mode != hidden_mode_from_spec_info
|
||||
):
|
||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||
|
||||
Reference in New Issue
Block a user