Open AI API hidden states (#6716)
@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
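Note: with `return_hidden_states` widened to `Union[List[bool], bool]`, a batched `/generate` call can toggle hidden states per prompt. A minimal sketch, assuming a local server on port 30000 launched with `--enable-return-hidden-states` (host, port, and payload values are assumptions, not part of this PR):

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": ["The capital of France is", "1+1="],
        "sampling_params": {"max_new_tokens": 8},
        # Per-prompt flags; a single bool would apply to the whole batch.
        "return_hidden_states": [True, False],
    },
)
for item in resp.json():
    # hidden_states shows up in meta_info only where it was requested.
    print("hidden_states" in item["meta_info"])
```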
@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=self.return_hidden_states,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -418,6 +418,20 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
+        if isinstance(obj, GenerateReqInput):
+            return_hidden_states = obj.return_hidden_states
+            has_return_hidden_states = return_hidden_states == True or (
+                isinstance(return_hidden_states, list) and any(return_hidden_states)
+            )
+            if (
+                not self.server_args.enable_return_hidden_states
+                and has_return_hidden_states
+            ):
+                raise ValueError(
+                    "return_hidden_states=True requires the server to be started "
+                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
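The guard above fails fast when a request asks for hidden states but the server was not launched with the flag. A sketch of the expected rejection (assumed local endpoint; the message text comes from the `ValueError` above):

```python
import requests

# Server assumed to be running WITHOUT --enable-return-hidden-states.
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "hello", "return_hidden_states": True},
)
# Expect an error response whose body mentions --enable-return-hidden-states.
print(resp.status_code, resp.text)
```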
@@ -235,6 +235,10 @@ class CudaGraphRunner:
                 self.model_runner.server_args.speculative_num_draft_tokens
             )
 
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -342,11 +346,29 @@ class CudaGraphRunner:
             else True
         )
 
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
 
-        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )
 
     def capture(self) -> None:
         profile_context = empty_context()
@@ -541,21 +563,34 @@ class CudaGraphRunner:
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
-        # If the capture_hidden_mode changes, we need to recapture the graph
-        hidden_mode_from_spec_info = getattr(
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        if (
-            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != CaptureHiddenMode.FULL
-        ):
-            self.capture_hidden_mode = CaptureHiddenMode.FULL
-            self.capture()
-        elif (
-            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != hidden_mode_from_spec_info
-        ):
-            self.capture_hidden_mode = hidden_mode_from_spec_info
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()
 
     def replay_prepare(
@@ -31,6 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE
 
 
+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2
 
     def need_capture(self):
         return self != CaptureHiddenMode.NULL
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST
 
+    def __lt__(self, other):
+        return self.value < other.value
+
 
 @dataclass
 class ForwardBatch:
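Because `NULL < LAST < FULL` and `@total_ordering` derives the remaining comparisons from `__lt__`, callers such as `recapture_if_needed` can pick the strictest requirement with a plain `max()`. A self-contained sketch of the same pattern (mirrors the enum above; not imported from SGLang):

```python
from enum import IntEnum
from functools import total_ordering


@total_ordering
class CaptureHiddenMode(IntEnum):
    NULL = 0  # capture nothing
    LAST = 1  # hidden state of the last token only
    FULL = 2  # hidden states of all tokens

    def __lt__(self, other):
        return self.value < other.value


# FULL can emulate LAST, and LAST can emulate NULL, so the max of all
# requirements is always a safe graph-capture mode.
required = max(CaptureHiddenMode.LAST, CaptureHiddenMode.NULL, CaptureHiddenMode.FULL)
assert required is CaptureHiddenMode.FULL
```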
@@ -542,6 +542,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -588,6 +589,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)
 
     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -599,6 +601,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -615,6 +618,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
@@ -683,6 +687,16 @@ def v1_generate_response(
         else:
             logprobs = None
 
+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         if to_file:
@@ -698,6 +712,8 @@ def v1_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
@@ -709,6 +725,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
 
         choices.append(choice_data)
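With `return_hidden_states` set on a non-streaming completion, the serialized choice gains one extra field. A hedged sketch of the resulting shape (values illustrative, not captured output; the vector is truncated to two numbers):

```python
# Illustrative parsed response per v1_generate_response above.
response = {
    "object": "text_completion",
    "choices": [
        {
            "index": 0,
            "text": " Paris.",
            "finish_reason": "stop",
            # Last-token hidden state; the key is omitted entirely when not requested.
            "hidden_states": [0.0123, -0.0456],
        }
    ],
}
last_token_hidden = response["choices"][0].get("hidden_states")
```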
@@ -777,6 +794,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}
 
         try:
             async for content in tokenizer_manager.generate_request(
@@ -791,6 +809,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get(
+                    "hidden_states", None
+                ) or hidden_states.get(index)
 
                 if not stream_buffer:  # The first chunk
                     if request.echo:
@@ -873,6 +894,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 n_prev_tokens[index] = n_prev_token
 
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    last_token_hidden_states = (
+                        choice_hidden_states[-1]
+                        if choice_hidden_states and len(choice_hidden_states) > 1
+                        else []
+                    )
+                    hidden_states_chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                text="",
+                                index=index,
+                                hidden_states=last_token_hidden_states,
+                                finish_reason=None,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
             if request.stream_options and request.stream_options.include_usage:
                 total_prompt_tokens = sum(
                     tokens
@@ -973,6 +1015,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -1215,6 +1258,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1233,6 +1277,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1259,6 +1304,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1319,6 +1365,20 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None
 
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         tool_calls = None
@@ -1391,6 +1451,8 @@ def v1_chat_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
        else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
@@ -1407,6 +1469,7 @@ def v1_chat_generate_response(
                    if finish_reason and "matched" in finish_reason
                    else None
                ),
+                hidden_states=hidden_states,
            )
 
         choices.append(choice_data)
@@ -1486,12 +1549,16 @@ async def v1_chat_completions(
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}
         try:
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
                 index = content.get("index", 0)
                 text = content["text"]
+                hidden_states[index] = content["meta_info"].get(
+                    "hidden_states", None
+                ) or hidden_states.get(index)
 
                 is_first = is_firsts.get(index, True)
                 stream_buffer = stream_buffers.get(index, "")
@@ -1613,6 +1680,7 @@ async def v1_chat_completions(
                 if (delta and len(delta) == 0) or not delta:
                     stream_buffers[index] = new_stream_buffer
                     is_firsts[index] = is_first
+                    n_prev_tokens[index] = n_prev_token
                     continue
 
                 if request.tool_choice != "none" and request.tools:
@@ -1702,6 +1770,7 @@ async def v1_chat_completions(
 
                     stream_buffers[index] = new_stream_buffer
                     is_firsts[index] = is_first
+                    n_prev_tokens[index] = n_prev_token
 
                 else:
                     # No tool calls => just treat this as normal text
@@ -1734,6 +1803,7 @@ async def v1_chat_completions(
                     yield f"data: {chunk.model_dump_json()}\n\n"
                     stream_buffers[index] = new_stream_buffer
                     is_firsts[index] = is_first
+                    n_prev_tokens[index] = n_prev_token
             if finish_reason_type == "stop" and request.tool_choice != "none":
                 parser = FunctionCallParser(
                     tools=request.tools,
@@ -1769,6 +1839,28 @@ async def v1_chat_completions(
 
             else:
                 usage = None
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    last_token_hidden_states = (
+                        choice_hidden_states[-1]
+                        if choice_hidden_states and len(choice_hidden_states) > 1
+                        else []
+                    )
+                    hidden_states_chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        choices=[
+                            ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(
+                                    hidden_states=last_token_hidden_states
+                                ),
+                                finish_reason=finish_reason_type,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
             final_usage_chunk = ChatCompletionStreamResponse(
                 id=content["meta_info"]["id"],
                 created=created,
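On the wire, the hidden states arrive as one extra SSE chunk per choice, just before the usage chunk. A sketch with the `openai` Python client, assuming a local SGLang server started with `--enable-return-hidden-states`; `return_hidden_states` goes through `extra_body` since it is not a standard OpenAI parameter:

```python
import openai

client = openai.Client(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,
    extra_body={"return_hidden_states": True},  # SGLang-specific field from this PR
)
for chunk in stream:
    delta = chunk.choices[0].delta
    # The trailing chunk carries hidden_states in the delta instead of content.
    hs = getattr(delta, "hidden_states", None)
    if hs:
        print(f"last-token hidden state, dim={len(hs)}")
```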
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_serializer, root_validator
 from typing_extensions import Literal
 
 
@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
+    return_hidden_states: Optional[bool] = False
 
     # For PD disaggregation
     bootstrap_host: Optional[str] = None
@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionResponse(BaseModel):
@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class CompletionStreamResponse(BaseModel):
@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    # Hidden States
+    return_hidden_states: Optional[bool] = False
+
 
 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponse(BaseModel):
@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer
+    def _serialize(self):
+        return exclude_if_none(self, ["hidden_states"])
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -513,3 +537,8 @@ class ScoringResponse(BaseModel):
     model: str
     usage: Optional[UsageInfo] = None
     object: str = "scoring"
+
+
+def exclude_if_none(obj, field_names: List[str]):
+    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
+    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
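`exclude_if_none` drops the listed fields from the serialized payload only when they are `None`, so responses that never asked for hidden states stay shaped exactly like stock OpenAI responses. A standalone sketch of the pattern (pydantic v2; `Demo` is a hypothetical model, not one from this PR):

```python
from typing import List, Optional

from pydantic import BaseModel, model_serializer


def exclude_if_none(obj, field_names: List[str]):
    omit_if_none_fields = {k for k in obj.model_fields if k in field_names}
    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}


class Demo(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer
    def _serialize(self):
        return exclude_if_none(self, ["hidden_states"])


print(Demo(text="hi").model_dump())  # {'text': 'hi'} (no hidden_states key)
print(Demo(text="hi", hidden_states=[0.1]).model_dump())  # key included
```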
@@ -215,6 +215,7 @@ class ServerArgs:
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     warmups: Optional[str] = None
+    enable_return_hidden_states: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -1456,6 +1457,12 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_inject,
             help="Inject the outputs from jax as the input of every layer.",
         )
+
+        parser.add_argument(
+            "--enable-return-hidden-states",
+            action="store_true",
+            help="Enable returning hidden states with responses.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-prefill-only",
             action="store_true",
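End to end, the feature is opt-in twice: once at launch via the new flag, once per request. A sketch for a non-streaming completion (model path, port, and `model="default"` are placeholder assumptions):

```python
# Launch first (shell):
#   python -m sglang.launch_server --model-path <model> --port 30000 \
#       --enable-return-hidden-states
import openai

client = openai.Client(base_url="http://localhost:30000/v1", api_key="EMPTY")

resp = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=4,
    extra_body={"return_hidden_states": True},  # pass-through field added by this PR
)
# Non-streaming: the last-token hidden state is attached to the choice.
print(getattr(resp.choices[0], "hidden_states", None))
```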
@@ -117,9 +117,7 @@ class EAGLEDraftCudaGraphRunner:
         hidden_states = self.hidden_states[:num_seqs]
 
         spec_info = EagleDraftInput(
-            topk_p=topk_p,
-            topk_index=topk_index,
-            hidden_states=hidden_states,
+            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
         )
 
         # Forward batch
@@ -290,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
+
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -431,10 +432,10 @@ class EAGLEWorker(TpModelWorker):
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
-
-        # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -547,11 +548,13 @@ class EAGLEWorker(TpModelWorker):
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
+        batch.return_hidden_states = False
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -687,15 +690,18 @@ class EAGLEWorker(TpModelWorker):
             hidden_states: Hidden states from the target model forward
             next_token_ids: Next token ids generated from the target forward.
         """
+        # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
         )
+        batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -718,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
             batch,
             self.speculative_num_steps,
         )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )