Add return hidden state in the native API (#3897)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com> Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -120,7 +120,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
|
||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||
# is very samll. We add more values here to make sure we capture the maximum bs.
|
||||
# is very small. We add more values here to make sure we capture the maximum bs.
|
||||
capture_bs = list(
|
||||
sorted(
|
||||
set(
|
||||
@@ -175,6 +175,7 @@ class CudaGraphRunner:
|
||||
# Batch sizes to capture
|
||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||
self.capture_forward_mode = ForwardMode.DECODE
|
||||
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||
self.num_tokens_per_bs = 1
|
||||
if model_runner.spec_algorithm.is_eagle():
|
||||
if self.model_runner.is_draft_worker:
|
||||
@@ -335,6 +336,10 @@ class CudaGraphRunner:
|
||||
gathered_buffer = None
|
||||
|
||||
spec_info = self.get_spec_info(num_tokens)
|
||||
if self.capture_hidden_mode != CaptureHiddenMode.FULL:
|
||||
self.capture_hidden_mode = (
|
||||
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||
)
|
||||
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=self.capture_forward_mode,
|
||||
@@ -355,15 +360,7 @@ class CudaGraphRunner:
|
||||
mrope_positions=mrope_positions,
|
||||
spec_algorithm=self.model_runner.spec_algorithm,
|
||||
spec_info=spec_info,
|
||||
capture_hidden_mode=(
|
||||
CaptureHiddenMode.FULL
|
||||
if self.model_runner.server_args.return_hidden_states
|
||||
else (
|
||||
spec_info.capture_hidden_mode
|
||||
if spec_info
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
),
|
||||
capture_hidden_mode=self.capture_hidden_mode,
|
||||
)
|
||||
|
||||
# Attention backend
|
||||
@@ -406,6 +403,23 @@ class CudaGraphRunner:
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch):
|
||||
assert forward_batch.out_cache_loc is not None
|
||||
hidden_mode_from_spec_info = getattr(
|
||||
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
)
|
||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||
if (
|
||||
forward_batch.sampling_info.return_hidden_states
|
||||
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||
):
|
||||
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
self.capture()
|
||||
elif (
|
||||
not forward_batch.sampling_info.return_hidden_states
|
||||
and self.capture_hidden_mode != hidden_mode_from_spec_info
|
||||
):
|
||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||
self.capture()
|
||||
|
||||
raw_bs = forward_batch.batch_size
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user