Add return hidden state in the native API (#3897)
Co-authored-by: Beichen-Ma <mabeichen12@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -14,12 +14,15 @@ class TestHiddenState(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
input_ids = tokenizer(prompts).input_ids
|
||||
|
||||
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||
sampling_params = {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
"return_hidden_states": True,
|
||||
}
|
||||
|
||||
engine = sgl.Engine(
|
||||
model_path=model_path,
|
||||
random_seed=42,
|
||||
return_hidden_states=True,
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params)
|
||||
@@ -72,6 +75,58 @@ class TestHiddenState(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_repeatedly_changes_hidden_states(self):
|
||||
prompts = ["Today is", "Today is a sunny day and I like"]
|
||||
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
input_ids = tokenizer(prompts).input_ids
|
||||
|
||||
sample_completion = {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
"return_hidden_states": True,
|
||||
}
|
||||
|
||||
sample_hidden_state = {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 8,
|
||||
"return_hidden_states": False,
|
||||
}
|
||||
|
||||
engine = sgl.Engine(
|
||||
model_path=model_path,
|
||||
random_seed=42,
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
outputs_completion_first_round = engine.generate(
|
||||
input_ids=input_ids, sampling_params=sample_completion
|
||||
)
|
||||
outputs_hidden_state = engine.generate(
|
||||
input_ids=input_ids, sampling_params=sample_hidden_state
|
||||
)
|
||||
|
||||
outputs_completion_last_round = engine.generate(
|
||||
input_ids=input_ids, sampling_params=sample_completion
|
||||
)
|
||||
engine.shutdown()
|
||||
|
||||
for (
|
||||
output_completion_first_round,
|
||||
output_hidden_state,
|
||||
output_completion_last_round,
|
||||
) in zip(
|
||||
outputs_completion_first_round,
|
||||
outputs_hidden_state,
|
||||
outputs_completion_last_round,
|
||||
):
|
||||
self.assertEqual(
|
||||
len(output_completion_first_round["meta_info"]["hidden_states"]), 8
|
||||
)
|
||||
self.assertNotIn("hidden_states", output_hidden_state["meta_info"])
|
||||
self.assertEqual(
|
||||
len(output_completion_last_round["meta_info"]["hidden_states"]), 8
|
||||
)
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python test_hidden_state.py`).
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user