Open AI API hidden states (#6716)
This commit is contained in:
@@ -235,6 +235,10 @@ class CudaGraphRunner:
|
||||
self.model_runner.server_args.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
|
||||
if model_runner.server_args.enable_return_hidden_states:
|
||||
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
|
||||
# Attention backend
|
||||
self.max_bs = max(self.capture_bs)
|
||||
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||
@@ -342,11 +346,29 @@ class CudaGraphRunner:
|
||||
else True
|
||||
)
|
||||
|
||||
requested_capture_hidden_mode = max(
|
||||
forward_batch.capture_hidden_mode,
|
||||
(
|
||||
forward_batch.spec_info.capture_hidden_mode
|
||||
if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
|
||||
is not None
|
||||
else CaptureHiddenMode.NULL
|
||||
),
|
||||
)
|
||||
capture_hidden_mode_matches = (
|
||||
requested_capture_hidden_mode == CaptureHiddenMode.NULL
|
||||
or requested_capture_hidden_mode == self.capture_hidden_mode
|
||||
)
|
||||
is_tbo_supported = (
|
||||
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
||||
)
|
||||
|
||||
return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
|
||||
return (
|
||||
is_bs_supported
|
||||
and is_encoder_lens_supported
|
||||
and is_tbo_supported
|
||||
and capture_hidden_mode_matches
|
||||
)
|
||||
|
||||
def capture(self) -> None:
|
||||
profile_context = empty_context()
|
||||
@@ -541,21 +563,34 @@ class CudaGraphRunner:
|
||||
return graph, out
|
||||
|
||||
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||
hidden_mode_from_spec_info = getattr(
|
||||
|
||||
# If the required capture_hidden_mode changes, we need to recapture the graph
|
||||
|
||||
# These are the different factors that can influence the capture_hidden_mode
|
||||
capture_hidden_mode_required_by_forward_batch = (
|
||||
forward_batch.capture_hidden_mode
|
||||
)
|
||||
capture_hidden_mode_required_by_spec_info = getattr(
|
||||
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
)
|
||||
if (
|
||||
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
|
||||
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||
):
|
||||
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
self.capture()
|
||||
elif (
|
||||
forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||
and self.capture_hidden_mode != hidden_mode_from_spec_info
|
||||
):
|
||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||
capture_hidden_mode_required_for_returning_hidden_states = (
|
||||
CaptureHiddenMode.FULL
|
||||
if self.model_runner.server_args.enable_return_hidden_states
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
|
||||
# Determine the highest capture_hidden_mode required
|
||||
# (If we have FULL, we can emulate LAST or NULL)
|
||||
# (If we have LAST, we can emulate NULL)
|
||||
required_capture_hidden_mode = max(
|
||||
capture_hidden_mode_required_by_forward_batch,
|
||||
capture_hidden_mode_required_by_spec_info,
|
||||
capture_hidden_mode_required_for_returning_hidden_states,
|
||||
)
|
||||
|
||||
# If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
|
||||
if self.capture_hidden_mode != required_capture_hidden_mode:
|
||||
self.capture_hidden_mode = required_capture_hidden_mode
|
||||
self.capture()
|
||||
|
||||
def replay_prepare(
|
||||
|
||||
@@ -31,6 +31,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from functools import total_ordering
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
|
||||
return self == ForwardMode.DECODE or self == ForwardMode.IDLE
|
||||
|
||||
|
||||
@total_ordering
|
||||
class CaptureHiddenMode(IntEnum):
|
||||
# Do not capture anything.
|
||||
NULL = auto()
|
||||
# Capture hidden states of all tokens.
|
||||
FULL = auto()
|
||||
NULL = 0
|
||||
# Capture a hidden state of the last token.
|
||||
LAST = auto()
|
||||
LAST = 1
|
||||
# Capture hidden states of all tokens.
|
||||
FULL = 2
|
||||
|
||||
def need_capture(self):
|
||||
return self != CaptureHiddenMode.NULL
|
||||
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
|
||||
def is_last(self):
|
||||
return self == CaptureHiddenMode.LAST
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardBatch:
|
||||
|
||||
Reference in New Issue
Block a user