Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -57,7 +57,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Generate (text generation model)\n",
|
"## Generate (text generation model)\n",
|
||||||
"Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](../references/sampling_params.md)."
|
"Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](https://docs.sglang.ai/backend/sampling_params.html)."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ The `/generate` endpoint accepts the following parameters in JSON format. For in
|
|||||||
* `stream`: Whether to stream the output. `bool = False`
|
* `stream`: Whether to stream the output. `bool = False`
|
||||||
* `lora_path`: Path to LoRA weights. `Optional[Union[List[Optional[str]], Optional[str]]] = None`
|
* `lora_path`: Path to LoRA weights. `Optional[Union[List[Optional[str]], Optional[str]]] = None`
|
||||||
* `custom_logit_processor`: Custom logit processor for advanced sampling control. For usage see below. `Optional[Union[List[Optional[str]], str]] = None`
|
* `custom_logit_processor`: Custom logit processor for advanced sampling control. For usage see below. `Optional[Union[List[Optional[str]], str]] = None`
|
||||||
|
* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False`
|
||||||
|
|
||||||
## Sampling params
|
## Sampling params
|
||||||
|
|
||||||
@@ -55,8 +56,6 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan
|
|||||||
* `ignore_eos`: Don't stop generation when EOS token is sampled. `bool = False`
|
* `ignore_eos`: Don't stop generation when EOS token is sampled. `bool = False`
|
||||||
* `skip_special_tokens`: Remove special tokens during decoding. `bool = True`
|
* `skip_special_tokens`: Remove special tokens during decoding. `bool = True`
|
||||||
* `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below. `Optional[List[Optional[Dict[str, Any]]]] = None`
|
* `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below. `Optional[List[Optional[Dict[str, Any]]]] = None`
|
||||||
* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False`
|
|
||||||
|
|
||||||
|
|
||||||
### Custom Logit Processor
|
### Custom Logit Processor
|
||||||
|
|
||||||
|
|||||||
@@ -26,10 +26,11 @@ def main():
|
|||||||
"temperature": 0.8,
|
"temperature": 0.8,
|
||||||
"top_p": 0.95,
|
"top_p": 0.95,
|
||||||
"max_new_tokens": 10,
|
"max_new_tokens": 10,
|
||||||
"return_hidden_states": True,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params=sampling_params)
|
outputs = llm.generate(
|
||||||
|
prompts, sampling_params=sampling_params, return_hidden_states=True
|
||||||
|
)
|
||||||
for prompt, output in zip(prompts, outputs):
|
for prompt, output in zip(prompts, outputs):
|
||||||
print("===============================")
|
print("===============================")
|
||||||
print(
|
print(
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ class Engine:
|
|||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
|
return_hidden_states: bool = False,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
) -> Union[Dict, Iterator[Dict]]:
|
) -> Union[Dict, Iterator[Dict]]:
|
||||||
"""
|
"""
|
||||||
@@ -144,6 +145,7 @@ class Engine:
|
|||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
modalities=modalities_list,
|
modalities=modalities_list,
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=custom_logit_processor,
|
||||||
|
return_hidden_states=return_hidden_states,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|||||||
@@ -69,11 +69,15 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
# Session info for continual prompting
|
# Session info for continual prompting
|
||||||
session_params: Optional[Union[List[Dict], Dict]] = None
|
session_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
|
|
||||||
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
||||||
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
||||||
# Use the processor's `to_str()` method to generate the serialized string.
|
# Use the processor's `to_str()` method to generate the serialized string.
|
||||||
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||||
|
|
||||||
|
# Whether to return hidden states
|
||||||
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
if (
|
if (
|
||||||
self.text is None and self.input_ids is None and self.input_embeds is None
|
self.text is None and self.input_ids is None and self.input_embeds is None
|
||||||
@@ -218,6 +222,7 @@ class GenerateReqInput:
|
|||||||
if self.custom_logit_processor is not None
|
if self.custom_logit_processor is not None
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
return_hidden_states=self.return_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -255,6 +260,9 @@ class TokenizedGenerateReqInput:
|
|||||||
# Use the processor's `to_str()` method to generate the serialized string.
|
# Use the processor's `to_str()` method to generate the serialized string.
|
||||||
custom_logit_processor: Optional[str] = None
|
custom_logit_processor: Optional[str] = None
|
||||||
|
|
||||||
|
# Whether to return hidden states
|
||||||
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
|
|||||||
@@ -236,6 +236,7 @@ class Req:
|
|||||||
input_embeds: Optional[List[List[float]]] = None,
|
input_embeds: Optional[List[List[float]]] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
custom_logit_processor: Optional[str] = None,
|
custom_logit_processor: Optional[str] = None,
|
||||||
|
return_hidden_states: bool = False,
|
||||||
eos_token_ids: Optional[Set[int]] = None,
|
eos_token_ids: Optional[Set[int]] = None,
|
||||||
):
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
@@ -256,7 +257,9 @@ class Req:
|
|||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
|
|
||||||
self.custom_logit_processor = custom_logit_processor
|
self.custom_logit_processor = custom_logit_processor
|
||||||
|
self.return_hidden_states = return_hidden_states
|
||||||
|
|
||||||
# Memory pool info
|
# Memory pool info
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
@@ -608,6 +611,9 @@ class ScheduleBatch:
|
|||||||
# Enable custom logit processor
|
# Enable custom logit processor
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
|
|
||||||
|
# Whether to return hidden states
|
||||||
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
@@ -619,6 +625,7 @@ class ScheduleBatch:
|
|||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
spec_algorithm: SpeculativeAlgorithm,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
enable_custom_logit_processor: bool,
|
enable_custom_logit_processor: bool,
|
||||||
|
return_hidden_states: bool = False,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -633,6 +640,7 @@ class ScheduleBatch:
|
|||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||||
|
return_hidden_states=return_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1153,6 +1161,7 @@ class ScheduleBatch:
|
|||||||
self.return_logprob |= other.return_logprob
|
self.return_logprob |= other.return_logprob
|
||||||
self.has_stream |= other.has_stream
|
self.has_stream |= other.has_stream
|
||||||
self.has_grammar |= other.has_grammar
|
self.has_grammar |= other.has_grammar
|
||||||
|
self.return_hidden_states |= other.return_hidden_states
|
||||||
|
|
||||||
if self.spec_info:
|
if self.spec_info:
|
||||||
self.spec_info.merge_batch(other.spec_info)
|
self.spec_info.merge_batch(other.spec_info)
|
||||||
@@ -1201,7 +1210,7 @@ class ScheduleBatch:
|
|||||||
spec_info=self.spec_info,
|
spec_info=self.spec_info,
|
||||||
capture_hidden_mode=(
|
capture_hidden_mode=(
|
||||||
CaptureHiddenMode.FULL
|
CaptureHiddenMode.FULL
|
||||||
if self.sampling_info.return_hidden_states
|
if self.return_hidden_states
|
||||||
else (
|
else (
|
||||||
getattr(
|
getattr(
|
||||||
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||||
|
|||||||
@@ -631,6 +631,7 @@ class Scheduler:
|
|||||||
lora_path=recv_req.lora_path,
|
lora_path=recv_req.lora_path,
|
||||||
input_embeds=recv_req.input_embeds,
|
input_embeds=recv_req.input_embeds,
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=custom_logit_processor,
|
||||||
|
return_hidden_states=recv_req.return_hidden_states,
|
||||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
@@ -947,9 +948,11 @@ class Scheduler:
|
|||||||
if self.running_batch is not None
|
if self.running_batch is not None
|
||||||
else set([])
|
else set([])
|
||||||
)
|
)
|
||||||
|
return_hidden_states = False
|
||||||
# Get requests from the waiting queue to a new prefill batch
|
# Get requests from the waiting queue to a new prefill batch
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
|
if req.return_hidden_states:
|
||||||
|
return_hidden_states = True
|
||||||
if (
|
if (
|
||||||
self.lora_paths
|
self.lora_paths
|
||||||
and len(
|
and len(
|
||||||
@@ -1035,6 +1038,7 @@ class Scheduler:
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
|
return_hidden_states,
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
@@ -1226,7 +1230,7 @@ class Scheduler:
|
|||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
req.sampling_params.return_hidden_states
|
req.return_hidden_states
|
||||||
and logits_output.hidden_states is not None
|
and logits_output.hidden_states is not None
|
||||||
):
|
):
|
||||||
req.hidden_states.append(
|
req.hidden_states.append(
|
||||||
@@ -1333,10 +1337,7 @@ class Scheduler:
|
|||||||
logits_output.next_token_top_logprobs_idx[i]
|
logits_output.next_token_top_logprobs_idx[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if req.return_hidden_states and logits_output.hidden_states is not None:
|
||||||
req.sampling_params.return_hidden_states
|
|
||||||
and logits_output.hidden_states is not None
|
|
||||||
):
|
|
||||||
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
|
||||||
|
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
@@ -1462,10 +1463,7 @@ class Scheduler:
|
|||||||
completion_tokens = []
|
completion_tokens = []
|
||||||
cached_tokens = []
|
cached_tokens = []
|
||||||
spec_verify_ct = []
|
spec_verify_ct = []
|
||||||
return_hidden_states = any(
|
output_hidden_states = None
|
||||||
req.sampling_params.return_hidden_states for req in reqs
|
|
||||||
)
|
|
||||||
output_hidden_states = [] if return_hidden_states else None
|
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
input_token_logprobs_val = []
|
input_token_logprobs_val = []
|
||||||
@@ -1532,7 +1530,9 @@ class Scheduler:
|
|||||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||||
|
|
||||||
if req.sampling_params.return_hidden_states:
|
if req.return_hidden_states:
|
||||||
|
if output_hidden_states is None:
|
||||||
|
output_hidden_states = []
|
||||||
output_hidden_states.append(req.hidden_states)
|
output_hidden_states.append(req.hidden_states)
|
||||||
|
|
||||||
# Send to detokenizer
|
# Send to detokenizer
|
||||||
|
|||||||
@@ -383,6 +383,7 @@ class TokenizerManager:
|
|||||||
input_embeds=input_embeds,
|
input_embeds=input_embeds,
|
||||||
session_params=session_params,
|
session_params=session_params,
|
||||||
custom_logit_processor=obj.custom_logit_processor,
|
custom_logit_processor=obj.custom_logit_processor,
|
||||||
|
return_hidden_states=obj.return_hidden_states,
|
||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
|
|||||||
@@ -408,13 +408,13 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||||
if (
|
if (
|
||||||
forward_batch.sampling_info.return_hidden_states
|
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
|
||||||
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||||
):
|
):
|
||||||
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
self.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
self.capture()
|
self.capture()
|
||||||
elif (
|
elif (
|
||||||
not forward_batch.sampling_info.return_hidden_states
|
forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||||
and self.capture_hidden_mode != hidden_mode_from_spec_info
|
and self.capture_hidden_mode != hidden_mode_from_spec_info
|
||||||
):
|
):
|
||||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||||
|
|||||||
@@ -37,9 +37,6 @@ class SamplingBatchInfo:
|
|||||||
# Whether any request has custom logit processor
|
# Whether any request has custom logit processor
|
||||||
has_custom_logit_processor: bool
|
has_custom_logit_processor: bool
|
||||||
|
|
||||||
# Whether any request needs to return hidden states
|
|
||||||
return_hidden_states: bool
|
|
||||||
|
|
||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
grammars: Optional[List] = None
|
grammars: Optional[List] = None
|
||||||
@@ -94,9 +91,6 @@ class SamplingBatchInfo:
|
|||||||
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
and any(r.custom_logit_processor for r in reqs) # then check the requests.
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if any request needs to return hidden states
|
|
||||||
return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs)
|
|
||||||
|
|
||||||
if has_custom_logit_processor:
|
if has_custom_logit_processor:
|
||||||
# Merge the same type of custom logit processors together
|
# Merge the same type of custom logit processors together
|
||||||
processor_dict = {}
|
processor_dict = {}
|
||||||
@@ -136,7 +130,6 @@ class SamplingBatchInfo:
|
|||||||
device=device,
|
device=device,
|
||||||
custom_params=custom_params,
|
custom_params=custom_params,
|
||||||
custom_logit_processor=merged_custom_logit_processor,
|
custom_logit_processor=merged_custom_logit_processor,
|
||||||
return_hidden_states=return_hidden_states,
|
|
||||||
)
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
|
|
||||||
@@ -344,9 +337,6 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge the return hidden states flag
|
|
||||||
self.return_hidden_states |= other.return_hidden_states
|
|
||||||
|
|
||||||
# Merge the custom logit processors and custom params lists
|
# Merge the custom logit processors and custom params lists
|
||||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||||
# Merge the custom logit processors
|
# Merge the custom logit processors
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ class SamplingParams:
|
|||||||
no_stop_trim: bool = False,
|
no_stop_trim: bool = False,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
return_hidden_states: bool = False,
|
|
||||||
custom_params: Optional[Dict[str, Any]] = None,
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
@@ -75,7 +74,6 @@ class SamplingParams:
|
|||||||
self.ebnf = ebnf
|
self.ebnf = ebnf
|
||||||
self.structural_tag = structural_tag
|
self.structural_tag = structural_tag
|
||||||
self.no_stop_trim = no_stop_trim
|
self.no_stop_trim = no_stop_trim
|
||||||
self.return_hidden_states = return_hidden_states
|
|
||||||
self.custom_params = custom_params
|
self.custom_params = custom_params
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
sampling_params = {
|
sampling_params = {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 8,
|
"max_new_tokens": 8,
|
||||||
"return_hidden_states": True,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
engine = sgl.Engine(
|
engine = sgl.Engine(
|
||||||
@@ -25,7 +24,11 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
random_seed=42,
|
random_seed=42,
|
||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
|
outputs = engine.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_hidden_states=True,
|
||||||
|
)
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
@@ -81,16 +84,9 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
input_ids = tokenizer(prompts).input_ids
|
input_ids = tokenizer(prompts).input_ids
|
||||||
|
|
||||||
sample_completion = {
|
sampling_params = {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 8,
|
"max_new_tokens": 8,
|
||||||
"return_hidden_states": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
sample_hidden_state = {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": 8,
|
|
||||||
"return_hidden_states": False,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
engine = sgl.Engine(
|
engine = sgl.Engine(
|
||||||
@@ -99,14 +95,20 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
outputs_completion_first_round = engine.generate(
|
outputs_completion_first_round = engine.generate(
|
||||||
input_ids=input_ids, sampling_params=sample_completion
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_hidden_states=True,
|
||||||
)
|
)
|
||||||
outputs_hidden_state = engine.generate(
|
outputs_hidden_state = engine.generate(
|
||||||
input_ids=input_ids, sampling_params=sample_hidden_state
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_hidden_states=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs_completion_last_round = engine.generate(
|
outputs_completion_last_round = engine.generate(
|
||||||
input_ids=input_ids, sampling_params=sample_completion
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_hidden_states=True,
|
||||||
)
|
)
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user