diff --git a/tests/ut/models/test_mla.py b/tests/ut/models/test_mla.py index 0f3e166f..87fedc22 100644 --- a/tests/ut/models/test_mla.py +++ b/tests/ut/models/test_mla.py @@ -211,5 +211,3 @@ class TestAscendMultiHeadLatentAttention(TestBase): output = attn.forward(positions, hidden_states) self.assertEqual(output.shape, (3, self.hidden_size)) - self.assertTrue( - torch.allclose(output, output.view(-1, self.hidden_size)))