Refactor: Move return_hidden_states to the generate input (#3985)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-03-01 20:51:29 -05:00
committed by GitHub
parent 18bb216c28
commit 40782f05d7
12 changed files with 54 additions and 44 deletions

View File

@@ -408,13 +408,13 @@ class CudaGraphRunner:
)
# If the capture_hidden_mode changes, we need to recapture the graph
if (
forward_batch.sampling_info.return_hidden_states
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
and self.capture_hidden_mode != CaptureHiddenMode.FULL
):
self.capture_hidden_mode = CaptureHiddenMode.FULL
self.capture()
elif (
not forward_batch.sampling_info.return_hidden_states
forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
and self.capture_hidden_mode != hidden_mode_from_spec_info
):
self.capture_hidden_mode = hidden_mode_from_spec_info