OpenAI API hidden states (#6716)

This commit is contained in:
kyle-pena-kuzco
2025-06-10 17:37:29 -04:00
committed by GitHub
parent ce5ee3bdf0
commit b56de8f943
17 changed files with 606 additions and 44 deletions

View File

@@ -235,6 +235,10 @@ class CudaGraphRunner:
self.model_runner.server_args.speculative_num_draft_tokens
)
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
if model_runner.server_args.enable_return_hidden_states:
self.capture_hidden_mode = CaptureHiddenMode.FULL
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -342,11 +346,29 @@ class CudaGraphRunner:
else True
)
requested_capture_hidden_mode = max(
forward_batch.capture_hidden_mode,
(
forward_batch.spec_info.capture_hidden_mode
if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
is not None
else CaptureHiddenMode.NULL
),
)
capture_hidden_mode_matches = (
requested_capture_hidden_mode == CaptureHiddenMode.NULL
or requested_capture_hidden_mode == self.capture_hidden_mode
)
is_tbo_supported = (
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
)
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
return (
is_bs_supported
and is_encoder_lens_supported
and is_tbo_supported
and capture_hidden_mode_matches
)
def capture(self) -> None:
profile_context = empty_context()
@@ -541,21 +563,34 @@ class CudaGraphRunner:
return graph, out
def recapture_if_needed(self, forward_batch: ForwardBatch):
    """Re-capture the CUDA graph if the required CaptureHiddenMode changed.

    Three factors can each demand a (possibly different) capture mode:
    the batch itself, speculative-decoding info attached to the batch, and
    the server-level ``enable_return_hidden_states`` flag. The effective
    requirement is the maximum of the three (modes are ordered so that a
    "bigger" mode can emulate a smaller one: FULL can serve LAST or NULL,
    and LAST can serve NULL). If the currently captured mode differs from
    that requirement, the graph is re-captured.

    Args:
        forward_batch: The batch about to be replayed; only its
            ``capture_hidden_mode`` and (optional) ``spec_info`` are read.
    """
    # Factor 1: what the batch itself asks for.
    capture_hidden_mode_required_by_forward_batch = (
        forward_batch.capture_hidden_mode
    )
    # Factor 2: what speculative decoding asks for. spec_info may be absent
    # or lack the attribute, hence the getattr with a NULL default.
    capture_hidden_mode_required_by_spec_info = getattr(
        forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
    )
    # Factor 3: returning hidden states to the client needs every token's
    # hidden state, i.e. FULL capture.
    capture_hidden_mode_required_for_returning_hidden_states = (
        CaptureHiddenMode.FULL
        if self.model_runner.server_args.enable_return_hidden_states
        else CaptureHiddenMode.NULL
    )
    # The highest requirement wins (FULL > LAST > NULL); relies on
    # CaptureHiddenMode being totally ordered.
    required_capture_hidden_mode = max(
        capture_hidden_mode_required_by_forward_batch,
        capture_hidden_mode_required_by_spec_info,
        capture_hidden_mode_required_for_returning_hidden_states,
    )
    # Re-capture only on an actual change — capture() is expensive.
    if self.capture_hidden_mode != required_capture_hidden_mode:
        self.capture_hidden_mode = required_capture_hidden_mode
        self.capture()
def replay_prepare(

View File

@@ -31,6 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from functools import total_ordering
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
return self == ForwardMode.DECODE or self == ForwardMode.IDLE
@total_ordering
class CaptureHiddenMode(IntEnum):
    """How much of the model's hidden states to capture in a forward pass.

    Values are explicitly ordered NULL < LAST < FULL so that ``max(...)``
    over several requested modes selects the most demanding one (a FULL
    capture can emulate LAST or NULL, and LAST can emulate NULL).
    """

    # Do not capture anything.
    NULL = 0
    # Capture a hidden state of the last token.
    LAST = 1
    # Capture hidden states of all tokens.
    FULL = 2

    def need_capture(self):
        # Anything other than NULL requires capturing hidden states.
        return self != CaptureHiddenMode.NULL

    def is_last(self):
        return self == CaptureHiddenMode.LAST

    def __lt__(self, other):
        # Combined with __eq__ from IntEnum, @total_ordering derives the
        # remaining comparison operators from this.
        return self.value < other.value
@dataclass
class ForwardBatch: