Add return hidden state in the native API (#3897)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com> Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -55,6 +55,7 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan
|
|||||||
* `ignore_eos`: Don't stop generation when EOS token is sampled.
|
* `ignore_eos`: Don't stop generation when EOS token is sampled.
|
||||||
* `skip_special_tokens`: Remove special tokens during decoding.
|
* `skip_special_tokens`: Remove special tokens during decoding.
|
||||||
* `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below.
|
* `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below.
|
||||||
|
* `return_hidden_states`: Whether to return the hidden states of the model. Note that each time this value changes, the CUDA graph is recaptured, which may cause a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information.
|
||||||
|
|
||||||
|
|
||||||
### Custom Logit Processor
|
### Custom Logit Processor
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
Usage:
|
Usage:
|
||||||
python hidden_states.py
|
python hidden_states.py
|
||||||
|
|
||||||
Note that we are actively working on moving return_hidden_states to the sampling_params.
|
Note that each time you change the `return_hidden_states` parameter,
|
||||||
|
the CUDA graph will be recaptured, which might lead to a performance hit.
|
||||||
|
So avoid alternating between requests that return hidden states and requests that do not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
@@ -18,10 +20,14 @@ def main():
|
|||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = sgl.Engine(
|
llm = sgl.Engine(
|
||||||
model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||||
return_hidden_states=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10}
|
sampling_params = {
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"max_new_tokens": 10,
|
||||||
|
"return_hidden_states": True,
|
||||||
|
}
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params=sampling_params)
|
outputs = llm.generate(prompts, sampling_params=sampling_params)
|
||||||
for prompt, output in zip(prompts, outputs):
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
|||||||
@@ -607,9 +607,6 @@ class ScheduleBatch:
|
|||||||
# Enable custom logit processor
|
# Enable custom logit processor
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
|
|
||||||
# Return hidden states
|
|
||||||
return_hidden_states: bool = False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -621,7 +618,6 @@ class ScheduleBatch:
|
|||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
spec_algorithm: SpeculativeAlgorithm,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
enable_custom_logit_processor: bool,
|
enable_custom_logit_processor: bool,
|
||||||
return_hidden_states: bool = False,
|
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -636,7 +632,6 @@ class ScheduleBatch:
|
|||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||||
return_hidden_states=return_hidden_states,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1205,7 +1200,7 @@ class ScheduleBatch:
|
|||||||
spec_info=self.spec_info,
|
spec_info=self.spec_info,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=(
|
||||||
CaptureHiddenMode.FULL
|
CaptureHiddenMode.FULL
|
||||||
if self.return_hidden_states
|
if self.sampling_info.return_hidden_states
|
||||||
else (
|
else (
|
||||||
getattr(
|
getattr(
|
||||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||||
|
|||||||
@@ -1030,7 +1030,6 @@ class Scheduler:
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
self.server_args.return_hidden_states,
|
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
@@ -1221,9 +1220,8 @@ class Scheduler:
|
|||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.server_args.return_hidden_states
|
req.sampling_params.return_hidden_states
|
||||||
and logits_output.hidden_states is not None
|
and logits_output.hidden_states is not None
|
||||||
):
|
):
|
||||||
req.hidden_states.append(
|
req.hidden_states.append(
|
||||||
@@ -1331,7 +1329,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.server_args.return_hidden_states
|
req.sampling_params.return_hidden_states
|
||||||
and logits_output.hidden_states is not None
|
and logits_output.hidden_states is not None
|
||||||
):
|
):
|
||||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||||
@@ -1459,7 +1457,10 @@ class Scheduler:
|
|||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
cached_tokens = []
|
cached_tokens = []
|
||||||
spec_verify_ct = []
|
spec_verify_ct = []
|
||||||
output_hidden_states = [] if self.server_args.return_hidden_states else None
|
return_hidden_states = any(
|
||||||
|
req.sampling_params.return_hidden_states for req in reqs
|
||||||
|
)
|
||||||
|
output_hidden_states = [] if return_hidden_states else None
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
input_token_logprobs_val = []
|
input_token_logprobs_val = []
|
||||||
@@ -1526,7 +1527,7 @@ class Scheduler:
|
|||||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||||
|
|
||||||
if self.server_args.return_hidden_states:
|
if req.sampling_params.return_hidden_states:
|
||||||
output_hidden_states.append(req.hidden_states)
|
output_hidden_states.append(req.hidden_states)
|
||||||
|
|
||||||
# Send to detokenizer
|
# Send to detokenizer
|
||||||
@@ -1619,7 +1620,6 @@ class Scheduler:
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
self.server_args.return_hidden_states,
|
|
||||||
)
|
)
|
||||||
idle_batch.prepare_for_idle()
|
idle_batch.prepare_for_idle()
|
||||||
return idle_batch
|
return idle_batch
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
|
|
||||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||||
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||||
# is very samll. We add more values here to make sure we capture the maximum bs.
|
# is very small. We add more values here to make sure we capture the maximum bs.
|
||||||
capture_bs = list(
|
capture_bs = list(
|
||||||
sorted(
|
sorted(
|
||||||
set(
|
set(
|
||||||
@@ -175,6 +175,7 @@ class CudaGraphRunner:
|
|||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||||
self.capture_forward_mode = ForwardMode.DECODE
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
|
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||||
self.num_tokens_per_bs = 1
|
self.num_tokens_per_bs = 1
|
||||||
if model_runner.spec_algorithm.is_eagle():
|
if model_runner.spec_algorithm.is_eagle():
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
@@ -335,6 +336,10 @@ class CudaGraphRunner:
|
|||||||
gathered_buffer = None
|
gathered_buffer = None
|
||||||
|
|
||||||
spec_info = self.get_spec_info(num_tokens)
|
spec_info = self.get_spec_info(num_tokens)
|
||||||
|
if self.capture_hidden_mode != CaptureHiddenMode.FULL:
|
||||||
|
self.capture_hidden_mode = (
|
||||||
|
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||||
|
)
|
||||||
|
|
||||||
forward_batch = ForwardBatch(
|
forward_batch = ForwardBatch(
|
||||||
forward_mode=self.capture_forward_mode,
|
forward_mode=self.capture_forward_mode,
|
||||||
@@ -355,15 +360,7 @@ class CudaGraphRunner:
|
|||||||
mrope_positions=mrope_positions,
|
mrope_positions=mrope_positions,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=self.capture_hidden_mode,
|
||||||
CaptureHiddenMode.FULL
|
|
||||||
if self.model_runner.server_args.return_hidden_states
|
|
||||||
else (
|
|
||||||
spec_info.capture_hidden_mode
|
|
||||||
if spec_info
|
|
||||||
else CaptureHiddenMode.NULL
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
@@ -406,6 +403,23 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
assert forward_batch.out_cache_loc is not None
|
||||||
|
hidden_mode_from_spec_info = getattr(
|
||||||
|
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||||
|
)
|
||||||
|
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||||
|
if (
|
||||||
|
forward_batch.sampling_info.return_hidden_states
|
||||||
|
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||||
|
):
|
||||||
|
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
self.capture()
|
||||||
|
elif (
|
||||||
|
not forward_batch.sampling_info.return_hidden_states
|
||||||
|
and self.capture_hidden_mode != hidden_mode_from_spec_info
|
||||||
|
):
|
||||||
|
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||||
|
self.capture()
|
||||||
|
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
|
|||||||
# Whether any request has custom logit processor
|
# Whether any request has custom logit processor
|
||||||
has_custom_logit_processor: bool
|
has_custom_logit_processor: bool
|
||||||
|
|
||||||
|
# Whether any request needs to return hidden states
|
||||||
|
return_hidden_states: bool
|
||||||
|
|
||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
grammars: Optional[List] = None
|
grammars: Optional[List] = None
|
||||||
@@ -91,6 +94,9 @@ class SamplingBatchInfo:
|
|||||||
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if any request needs to return hidden states
|
||||||
|
return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
|
||||||
|
|
||||||
if has_custom_logit_processor:
|
if has_custom_logit_processor:
|
||||||
# Merge the same type of custom logit processors together
|
# Merge the same type of custom logit processors together
|
||||||
processor_dict = {}
|
processor_dict = {}
|
||||||
@@ -130,6 +136,7 @@ class SamplingBatchInfo:
|
|||||||
device=device,
|
device=device,
|
||||||
custom_params=custom_params,
|
custom_params=custom_params,
|
||||||
custom_logit_processor=merged_custom_logit_processor,
|
custom_logit_processor=merged_custom_logit_processor,
|
||||||
|
return_hidden_states=return_hidden_states,
|
||||||
)
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
|
|
||||||
@@ -336,6 +343,10 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Merge the return hidden states flag
|
||||||
|
self.return_hidden_states |= other.return_hidden_states
|
||||||
|
|
||||||
# Merge the custom logit processors and custom params lists
|
# Merge the custom logit processors and custom params lists
|
||||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||||
# Merge the custom logit processors
|
# Merge the custom logit processors
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class SamplingParams:
|
|||||||
no_stop_trim: bool = False,
|
no_stop_trim: bool = False,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
|
return_hidden_states: bool = False,
|
||||||
custom_params: Optional[Dict[str, Any]] = None,
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
@@ -72,6 +73,7 @@ class SamplingParams:
|
|||||||
self.json_schema = json_schema
|
self.json_schema = json_schema
|
||||||
self.ebnf = ebnf
|
self.ebnf = ebnf
|
||||||
self.no_stop_trim = no_stop_trim
|
self.no_stop_trim = no_stop_trim
|
||||||
|
self.return_hidden_states = return_hidden_states
|
||||||
self.custom_params = custom_params
|
self.custom_params = custom_params
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
|
|||||||
@@ -162,7 +162,6 @@ class ServerArgs:
|
|||||||
delete_ckpt_after_loading: bool = False
|
delete_ckpt_after_loading: bool = False
|
||||||
enable_memory_saver: bool = False
|
enable_memory_saver: bool = False
|
||||||
allow_auto_truncate: bool = False
|
allow_auto_truncate: bool = False
|
||||||
return_hidden_states: bool = False
|
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
tool_call_parser: str = None
|
tool_call_parser: str = None
|
||||||
enable_hierarchical_cache: bool = False
|
enable_hierarchical_cache: bool = False
|
||||||
@@ -917,11 +916,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--return-hidden-states",
|
|
||||||
action="store_true",
|
|
||||||
help="Return hidden states in the response.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tool-call-parser",
|
"--tool-call-parser",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -14,12 +14,15 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
input_ids = tokenizer(prompts).input_ids
|
input_ids = tokenizer(prompts).input_ids
|
||||||
|
|
||||||
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
sampling_params = {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
"return_hidden_states": True,
|
||||||
|
}
|
||||||
|
|
||||||
engine = sgl.Engine(
|
engine = sgl.Engine(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
return_hidden_states=True,
|
|
||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
|
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
|
||||||
@@ -72,6 +75,58 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_repeatedly_changes_hidden_states(self):
|
||||||
|
prompts = ["Today is", "Today is a sunny day and I like"]
|
||||||
|
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
input_ids = tokenizer(prompts).input_ids
|
||||||
|
|
||||||
|
sample_completion = {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
"return_hidden_states": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
sample_hidden_state = {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
"return_hidden_states": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = sgl.Engine(
|
||||||
|
model_path=model_path,
|
||||||
|
random_seed=42,
|
||||||
|
skip_tokenizer_init=True,
|
||||||
|
)
|
||||||
|
outputs_completion_first_round = engine.generate(
|
||||||
|
input_ids=input_ids, sampling_params=sample_completion
|
||||||
|
)
|
||||||
|
outputs_hidden_state = engine.generate(
|
||||||
|
input_ids=input_ids, sampling_params=sample_hidden_state
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_completion_last_round = engine.generate(
|
||||||
|
input_ids=input_ids, sampling_params=sample_completion
|
||||||
|
)
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
for (
|
||||||
|
output_completion_first_round,
|
||||||
|
output_hidden_state,
|
||||||
|
output_completion_last_round,
|
||||||
|
) in zip(
|
||||||
|
outputs_completion_first_round,
|
||||||
|
outputs_hidden_state,
|
||||||
|
outputs_completion_last_round,
|
||||||
|
):
|
||||||
|
self.assertEqual(
|
||||||
|
len(output_completion_first_round["meta_info"]["hidden_states"]), 8
|
||||||
|
)
|
||||||
|
self.assertNotIn("hidden_states", output_hidden_state["meta_info"])
|
||||||
|
self.assertEqual(
|
||||||
|
len(output_completion_last_round["meta_info"]["hidden_states"]), 8
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user