Refactor: Move return_hidden_states to the generate input (#3985)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
@@ -17,7 +17,6 @@ class TestHiddenState(unittest.TestCase):
|
||||
sampling_params = {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
"return_hidden_states": True,
|
||||
}
|
||||
|
||||
engine = sgl.Engine(
|
||||
@@ -25,7 +24,11 @@ class TestHiddenState(unittest.TestCase):
|
||||
random_seed=42,
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
|
||||
outputs = engine.generate(
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_hidden_states=True,
|
||||
)
|
||||
engine.shutdown()
|
||||
|
||||
for output in outputs:
|
||||
@@ -81,16 +84,9 @@ class TestHiddenState(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
input_ids = tokenizer(prompts).input_ids
|
||||
|
||||
sample_completion = {
|
||||
sampling_params = {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
"return_hidden_states": True,
|
||||
}
|
||||
|
||||
sample_hidden_state = {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
"return_hidden_states": False,
|
||||
}
|
||||
|
||||
engine = sgl.Engine(
|
||||
@@ -99,14 +95,20 @@ class TestHiddenState(unittest.TestCase):
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
outputs_completion_first_round = engine.generate(
|
||||
input_ids=input_ids, sampling_params=sample_completion
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_hidden_states=True,
|
||||
)
|
||||
outputs_hidden_state = engine.generate(
|
||||
input_ids=input_ids, sampling_params=sample_hidden_state
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_hidden_states=False,
|
||||
)
|
||||
|
||||
outputs_completion_last_round = engine.generate(
|
||||
input_ids=input_ids, sampling_params=sample_completion
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_hidden_states=True,
|
||||
)
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user