Fix the output of hidden states after HTTP requests (#4269)
This commit is contained in:
@@ -33,8 +33,11 @@ class TestHiddenState(unittest.TestCase):
|
||||
|
||||
for output in outputs:
|
||||
self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
|
||||
for hidden_state in output["meta_info"]["hidden_states"]:
|
||||
self.assertIsInstance(hidden_state, torch.Tensor)
|
||||
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||
assert isinstance(output["meta_info"]["hidden_states"][i], list)
|
||||
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||
)
|
||||
# Checks that splicing of the batch was done correctly
|
||||
self.assertGreater(
|
||||
outputs[1]["meta_info"]["hidden_states"][0].shape[0],
|
||||
|
||||
Reference in New Issue
Block a user