Fix the output of hidden states after HTTP requests (#4269)

This commit is contained in:
Qiaolin Yu
2025-03-13 17:54:06 -04:00
committed by GitHub
parent 5fe79605a8
commit 85d2365d33
5 changed files with 47 additions and 9 deletions

View File

@@ -33,8 +33,11 @@ class TestHiddenState(unittest.TestCase):
for output in outputs:
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
for hidden_state in output["meta_info"]["hidden_states"]:
self.assertIsInstance(hidden_state, torch.Tensor)
for i in range(len(output["meta_info"]["hidden_states"])):
assert isinstance(output["meta_info"]["hidden_states"][i], list)
output["meta_info"]["hidden_states"][i] = torch.tensor(
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
)
# Checks that splicing of the batch was done correctly
self.assertGreater(
outputs[1]["meta_info"]["hidden_states"][0].shape[0],