Add support for returning hidden states in the native API (#3897)

Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Qiaolin Yu
2025-02-27 01:06:54 -05:00
committed by GitHub
parent 71ed01833d
commit d6898dd253
9 changed files with 112 additions and 34 deletions

View File

@@ -14,12 +14,15 @@ class TestHiddenState(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer(prompts).input_ids
sampling_params = {"temperature": 0, "max_new_tokens": 8}
sampling_params = {
"temperature": 0,
"max_new_tokens": 8,
"return_hidden_states": True,
}
engine = sgl.Engine(
model_path=model_path,
random_seed=42,
return_hidden_states=True,
skip_tokenizer_init=True,
)
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
@@ -72,6 +75,58 @@ class TestHiddenState(unittest.TestCase):
)
)
def test_repeatedly_changes_hidden_states(self):
    """Check that ``return_hidden_states`` can be toggled between requests.

    A single engine serves three consecutive ``generate`` calls over the
    same prompts: hidden states requested, not requested, then requested
    again. For every prompt the test asserts that ``meta_info`` carries a
    ``hidden_states`` entry of the expected length in the first and last
    rounds, and no ``hidden_states`` key at all in the middle round.
    """
    prompts = ["Today is", "Today is a sunny day and I like"]
    model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # Shared by the sampling params and the length assertions below so the
    # two cannot drift apart.
    max_new_tokens = 8

    # The engine is created with skip_tokenizer_init=True, so prompts are
    # tokenized here and passed to generate() as token ids.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    input_ids = tokenizer(prompts).input_ids

    params_with_hidden_states = {
        "temperature": 0,
        "max_new_tokens": max_new_tokens,
        "return_hidden_states": True,
    }
    params_without_hidden_states = {
        "temperature": 0,
        "max_new_tokens": max_new_tokens,
        "return_hidden_states": False,
    }

    engine = sgl.Engine(
        model_path=model_path,
        random_seed=42,
        skip_tokenizer_init=True,
    )
    try:
        outputs_first_round = engine.generate(
            input_ids=input_ids, sampling_params=params_with_hidden_states
        )
        outputs_without = engine.generate(
            input_ids=input_ids, sampling_params=params_without_hidden_states
        )
        outputs_last_round = engine.generate(
            input_ids=input_ids, sampling_params=params_with_hidden_states
        )
    finally:
        # Always shut the engine down, even when a generate() call raises,
        # so a failing round does not leave the engine running.
        engine.shutdown()

    for first_round, without, last_round in zip(
        outputs_first_round, outputs_without, outputs_last_round
    ):
        # NOTE(review): the expected length equals max_new_tokens (8);
        # presumably one hidden-state entry per generated token - confirm.
        self.assertEqual(
            len(first_round["meta_info"]["hidden_states"]), max_new_tokens
        )
        self.assertNotIn("hidden_states", without["meta_info"])
        self.assertEqual(
            len(last_round["meta_info"]["hidden_states"]), max_new_tokens
        )
# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()