diff --git a/tests/ut/sample/test_rejection_sampler.py b/tests/ut/sample/test_rejection_sampler.py
new file mode 100644
index 0000000..b6aaf86
--- /dev/null
+++ b/tests/ut/sample/test_rejection_sampler.py
@@ -0,0 +1,201 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+from unittest.mock import patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.sample.rejection_sampler import (
+    expand_batch_to_tokens, expand_pytorch, rejection_greedy_sample_pytorch,
+    rejection_random_sample_pytorch, sample_recovered_tokens_pytorch)
+
+# Global constants
+PLACEHOLDER_TOKEN_ID = -1
+GREEDY_TEMPERATURE = 0.0
+MAX_SPEC_LEN = 8  # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
+
+
+class TestAscendRejectionSampler(TestBase):
+
+    def test_rejection_greedy_sample_pytorch(self):
+        """Greedy rejection sampling: accept drafts that match the target argmax, emit the target token at the first mismatch and stop, append the bonus token only if every draft matches."""
+        batch_size = 2
+        max_spec_len = 3
+        output_token_ids = torch.full((batch_size, max_spec_len + 1),
+                                      PLACEHOLDER_TOKEN_ID)
+
+        cu_num_draft_tokens = torch.tensor([2, 4])
+        draft_token_ids = torch.tensor([10, 11, 20, 21])
+        target_argmax = torch.tensor([10, 99, 20, 22])
+        bonus_token_ids = torch.tensor([[100], [200]])
+
+        is_greedy = torch.tensor([True, True])
+
+        rejection_greedy_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            target_argmax,
+            bonus_token_ids,
+            is_greedy,
+            max_spec_len,
+        )
+
+        # Request 0: draft 10 matches, draft 11 != target 99 -> emit 99, stop.
+        assert output_token_ids[0, 0].item() == 10
+        assert output_token_ids[0, 1].item() == 99
+        # Request 1: draft 20 matches, draft 21 != target 22 -> stop early,
+        # so the remaining slots keep the placeholder.
+        assert output_token_ids[1, 0].item() == 20
+        assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
+
+    def test_rejection_random_sample_pytorch(self):
+        """Random rejection sampling: accept each draft token when its uniform sample passes the target/draft probability check; append the bonus token when all drafts are accepted."""
+        batch_size = 2
+        max_spec_len = 3
+        output_token_ids = torch.full((batch_size, max_spec_len + 1),
+                                      PLACEHOLDER_TOKEN_ID)
+
+        cu_num_draft_tokens = torch.tensor([2, 1])
+        draft_token_ids = torch.tensor([1, 0, 2])
+        draft_probs = torch.tensor([
+            [0.0, 0.6, 0.0, 0.4],  # vocab_size=4
+            [0.1, 0.2, 0.3, 0.4],
+            [0.5, 0.5, 0.0, 0.0],
+        ])
+        target_probs = torch.tensor([
+            [0.0, 0.8, 0.0, 0.2],
+            [0.2, 0.1, 0.3, 0.4],
+            [0.9, 0.1, 0.0, 0.0],
+        ])
+        bonus_token_ids = torch.tensor([[100], [200]])
+        recovered_token_ids = torch.tensor([1, 2, 3])
+        uniform_probs = torch.tensor([0.7, 0.6, 0.5])
+        is_greedy = torch.tensor([False, False])
+        vocab_size = 4
+
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=False,
+        )
+
+        assert output_token_ids[0, 0].item() == 1
+        assert output_token_ids[0, 1].item() == 0
+        assert output_token_ids[0, 2].item() == 100
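+
+        # Worked check of the expected values above (a sketch, assuming the
+        # standard speculative-decoding acceptance rule
+        # u < p_target(token) / p_draft(token)):
+        #   pos 0: draft token 1, u=0.7 < 0.8 / 0.6 ~= 1.33 -> accept token 1
+        #   pos 1: draft token 0, u=0.6 < 0.2 / 0.1  = 2.0  -> accept token 0
+        # Both drafts of request 0 are accepted, so bonus token 100 follows.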
+
+    def test_expand_pytorch(self):
+        """expand_pytorch repeats each batch value according to the per-request counts encoded in cu_num_tokens."""
+        input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
+        cu_num_tokens_ptr = torch.tensor([2, 5, 7])
+        output_ptr = torch.empty(7, dtype=torch.int32)
+
+        expand_pytorch(
+            output_ptr,
+            input_ptr,
+            cu_num_tokens_ptr,
+            replace_from=0,
+            replace_to=0,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,
+        )
+
+        # cu_num_tokens [2, 5, 7] -> per-request counts [2, 3, 2].
+        expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
+        assert torch.equal(output_ptr, expected)
+
+    def test_expand_batch_to_tokens(self):
+        """expand_batch_to_tokens forwards to the expand kernel and returns the expanded per-token tensor."""
+        x = torch.tensor([10, 20, 30])
+        cu_num_tokens = torch.tensor([2, 5, 7])
+        num_tokens = 7
+
+        with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
+                   ) as mock_kernel:
+            expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
+            mock_kernel.assert_called_once()
+            args = mock_kernel.call_args[0]
+            assert (args[1] == x).all()
+            assert (args[2] == cu_num_tokens).all()
+
+        # Run the real function outside the patch context.
+        result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
+        expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
+        assert torch.equal(result, expected)
+
+    def test_sample_recovered_tokens_pytorch_ngram(self):
+        """Recovered-token sampling in n-gram mode (draft_probs is None)."""
+        output_token_ids = torch.empty(2, dtype=torch.int32)
+        cu_num_draft_tokens = torch.tensor([1, 2])
+        draft_token_ids = torch.tensor([1, 2])
+        draft_probs = None
+        target_probs = torch.tensor([
+            [0.1, 0.2, 0.7],
+            [0.3, 0.3, 0.4],
+        ])
+        q = torch.tensor([
+            [0.1, 0.2, 0.7],
+            [0.5, 0.4, 0.1],
+        ])
+        vocab_size = 3
+
+        sample_recovered_tokens_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            IS_NGRAM=True,
+        )
+
+        assert output_token_ids[0].item() == 0
+        assert output_token_ids[1].item() == 1
+
+    def test_sample_recovered_tokens_pytorch_autoregressive(self):
+        """Recovered-token sampling for autoregressive draft models (draft_probs provided)."""
+        output_token_ids = torch.empty(2, dtype=torch.int32)
+        cu_num_draft_tokens = torch.tensor([1, 1])
+        draft_token_ids = torch.tensor([0, 1])
+        draft_probs = torch.tensor([
+            [0.6, 0.1, 0.3],
+            [0.2, 0.7, 0.1],
+        ])
+        target_probs = torch.tensor([
+            [0.8, 0.1, 0.1],
+            [0.3, 0.6, 0.1],
+        ])
+        q = torch.tensor([
+            [0.5, 0.3, 0.2],
+            [0.1, 0.8, 0.1],
+        ])
+        vocab_size = 3
+
+        sample_recovered_tokens_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            IS_NGRAM=False,
+        )
+        assert output_token_ids[0].item() == 0
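+
+        # Worked check of the expected value above (a sketch, assuming the
+        # recovered token is drawn from the clamped residual distribution
+        # max(target_probs - draft_probs, 0) via the Gumbel-style trick
+        # argmax(prob / q)):
+        #   pos 0: residual [0.2, 0.0, 0.0]; / q -> [0.4, 0.0, 0.0] -> token 0
+        #   pos 1: residual [0.1, 0.0, 0.0]; / q -> [1.0, 0.0, 0.0] -> token 0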