Refactor: Move return_hidden_states to the generate input (#3985)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-03-01 20:51:29 -05:00
committed by GitHub
parent 18bb216c28
commit 40782f05d7
12 changed files with 54 additions and 44 deletions

View File

@@ -17,7 +17,6 @@ class TestHiddenState(unittest.TestCase):
  sampling_params = {
  "temperature": 0,
  "max_new_tokens": 8,
- "return_hidden_states": True,
  }
engine = sgl.Engine(
@@ -25,7 +24,11 @@ class TestHiddenState(unittest.TestCase):
random_seed=42,
skip_tokenizer_init=True,
)
- outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
+ outputs = engine.generate(
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_hidden_states=True,
+ )
engine.shutdown()
for output in outputs:
@@ -81,16 +84,9 @@ class TestHiddenState(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer(prompts).input_ids
- sample_completion = {
+ sampling_params = {
  "temperature": 0,
  "max_new_tokens": 8,
- "return_hidden_states": True,
  }
- sample_hidden_state = {
- "temperature": 0,
- "max_new_tokens": 8,
- "return_hidden_states": False,
- }
engine = sgl.Engine(
@@ -99,14 +95,20 @@ class TestHiddenState(unittest.TestCase):
skip_tokenizer_init=True,
)
outputs_completion_first_round = engine.generate(
- input_ids=input_ids, sampling_params=sample_completion
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_hidden_states=True,
  )
  outputs_hidden_state = engine.generate(
- input_ids=input_ids, sampling_params=sample_hidden_state
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_hidden_states=False,
  )
  outputs_completion_last_round = engine.generate(
- input_ids=input_ids, sampling_params=sample_completion
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_hidden_states=True,
)
engine.shutdown()