diff --git a/tests/ut/sample/test_rejection_sampler.py b/tests/ut/sample/test_rejection_sampler.py index adbf376d..f2f8ac19 100644 --- a/tests/ut/sample/test_rejection_sampler.py +++ b/tests/ut/sample/test_rejection_sampler.py @@ -174,7 +174,7 @@ class TestAscendRejectionSampler(TestBase): def test_sample_recovered_tokens_pytorch_autoregressive(self): """Test recovered token sampling for autoregressive models""" output_token_ids = torch.empty(2, dtype=torch.int32) - cu_num_draft_tokens = torch.tensor([1, 1]) + cu_num_draft_tokens = torch.tensor([1, 2]) draft_token_ids = torch.tensor([0, 1]) draft_probs = torch.tensor([ [0.6, 0.1, 0.3], @@ -201,3 +201,4 @@ class TestAscendRejectionSampler(TestBase): IS_NGRAM=False, ) assert output_token_ids[0].item() == 0 + assert output_token_ids[1].item() == 0