#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from unittest.mock import patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.sample.rejection_sampler import (
    expand_batch_to_tokens, expand_pytorch, rejection_greedy_sample_pytorch,
    rejection_random_sample_pytorch, sample_recovered_tokens_pytorch)

# Global constants
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = 0.0
MAX_SPEC_LEN = 8  # Used as MAX_NUM_TOKENS in expand_batch_to_tokens


class TestAscendRejectionSampler(TestBase):

    def test_rejection_greedy_sample_pytorch(self):
        """Test greedy rejection sampling: stop at the first draft token that
        does not match the target argmax; if every draft token matches,
        append the bonus token."""
        batch_size = 2
        max_spec_len = 2
        output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                      PLACEHOLDER_TOKEN_ID)
        # Cumulative draft-token offsets: request 0 owns tokens [0, 2),
        # request 1 owns tokens [2, 4).
        cu_num_draft_tokens = torch.tensor([2, 4])
        num_draft_tokens = [2, 2]
        draft_token_ids = torch.tensor([10, 11, 20, 21])
        target_argmax = torch.tensor([10, 99, 20, 22])
        bonus_token_ids = torch.tensor([[100], [200]])
        is_greedy = torch.tensor([True, True])

        rejection_greedy_sample_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            num_draft_tokens,
            max_spec_len,
            is_greedy,
        )

        # Request 0: the first draft token matches the target argmax (10);
        # the second (11 vs 99) does not, so the target token is emitted and
        # sampling stops before the bonus token.
        assert output_token_ids[0, 0].item() == 10
        assert output_token_ids[0, 1].item() == 99
        # Request 1: same pattern; position 2 keeps the placeholder because
        # the bonus token is never reached.
        assert output_token_ids[1, 0].item() == 20
        assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID

    def test_rejection_random_sample_pytorch(self):
        """Test random rejection sampling: accept a draft token when the
        uniform draw falls under the target/draft probability ratio."""
        batch_size = 2
        max_spec_len = 3
        output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                      PLACEHOLDER_TOKEN_ID)
        cu_num_draft_tokens = torch.tensor([2, 1])
        draft_token_ids = torch.tensor([1, 0, 2])
        draft_probs = torch.tensor([
            [0.0, 0.6, 0.0, 0.4],  # vocab_size=4
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.5, 0.0, 0.0],
        ])
        target_probs = torch.tensor([
            [0.0, 0.8, 0.0, 0.2],
            [0.2, 0.1, 0.3, 0.4],
            [0.9, 0.1, 0.0, 0.0],
        ])
        bonus_token_ids = torch.tensor([[100], [200]])
        recovered_token_ids = torch.tensor([1, 2, 3])
        uniform_probs = torch.tensor([0.7, 0.6, 0.5])
        is_greedy = torch.tensor([False, False])
        vocab_size = 4

        rejection_random_sample_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            bonus_token_ids,
            recovered_token_ids,
            uniform_probs,
            is_greedy,
            max_spec_len,
            vocab_size,
            IS_NGRAM=False,
        )

        # Request 0: both draft tokens are accepted
        # (0.8/0.6 >= 0.7 and 0.2/0.1 >= 0.6), so the bonus token follows.
        assert output_token_ids[0, 0].item() == 1
        assert output_token_ids[0, 1].item() == 0
        assert output_token_ids[0, 2].item() == 100

    def test_expand_pytorch(self):
        """Test expand_pytorch functionality"""
        input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
        # Cumulative counts: 2, 3, and 2 copies respectively.
        cu_num_tokens_ptr = torch.tensor([2, 5, 7])
        output_ptr = torch.empty(7, dtype=torch.int32)

        expand_pytorch(
            output_ptr,
            input_ptr,
            cu_num_tokens_ptr,
            replace_from=0,
            replace_to=0,
            MAX_NUM_TOKENS=MAX_SPEC_LEN,
        )

        expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
        assert torch.equal(output_ptr, expected)

    def test_expand_batch_to_tokens(self):
        """Test the expand_batch_to_tokens wrapper"""
        x = torch.tensor([10, 20, 30])
        cu_num_tokens = torch.tensor([2, 5, 7])
        num_tokens = 7

        # Verify the wrapper dispatches to the kernel with the right args.
        with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
                   ) as mock_kernel:
            expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
            mock_kernel.assert_called_once()
            args = mock_kernel.call_args[0]
            assert (args[1] == x).all()
            assert (args[2] == cu_num_tokens).all()

        # Run the actual function
        result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
        expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
        assert torch.equal(result, expected)

    def test_sample_recovered_tokens_pytorch_ngram(self):
        """Test recovered-token sampling in n-gram mode"""
        output_token_ids = torch.empty(2, dtype=torch.int32)
        cu_num_draft_tokens = torch.tensor([1, 2])
        draft_token_ids = torch.tensor([1, 2])
        draft_probs = None  # n-gram drafting has no draft distribution
        target_probs = torch.tensor([
            [0.1, 0.2, 0.7],
            [0.3, 0.3, 0.4],
        ])
        q = torch.tensor([
            [0.1, 0.2, 0.7],
            [0.5, 0.4, 0.1],
        ])
        vocab_size = 3

        sample_recovered_tokens_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            q,
            vocab_size,
            IS_NGRAM=True,
        )

        assert output_token_ids[0].item() == 0
        assert output_token_ids[1].item() == 1

    def test_sample_recovered_tokens_pytorch_autoregressive(self):
        """Test recovered-token sampling for autoregressive models"""
        output_token_ids = torch.empty(2, dtype=torch.int32)
        cu_num_draft_tokens = torch.tensor([1, 1])
        draft_token_ids = torch.tensor([0, 1])
        draft_probs = torch.tensor([
            [0.6, 0.1, 0.3],
            [0.2, 0.7, 0.1],
        ])
        target_probs = torch.tensor([
            [0.8, 0.1, 0.1],
            [0.3, 0.6, 0.1],
        ])
        q = torch.tensor([
            [0.5, 0.3, 0.2],
            [0.1, 0.8, 0.1],
        ])
        vocab_size = 3

        sample_recovered_tokens_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            q,
            vocab_size,
            IS_NGRAM=False,
        )

        assert output_token_ids[0].item() == 0
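
    # A hedged sanity check, not part of the original suite: it re-derives
    # the standard speculative-decoding acceptance rule that
    # test_rejection_random_sample_pytorch relies on (accept a draft token
    # when uniform <= target_prob / draft_prob, clamped to 1) directly from
    # that test's fixture values, without calling into
    # vllm_ascend.sample.rejection_sampler.
    def test_acceptance_rule_reference(self):
        """Reference computation of the rejection-sampling acceptance rule."""
        # P_draft and P_target for request 0's draft tokens [1, 0] in
        # test_rejection_random_sample_pytorch, with the uniform draws
        # from the same fixture.
        draft_prob = torch.tensor([0.6, 0.1])
        target_prob = torch.tensor([0.8, 0.2])
        uniform = torch.tensor([0.7, 0.6])

        accept_ratio = (target_prob / draft_prob).clamp(max=1.0)
        accepted = uniform <= accept_ratio

        # Both tokens are accepted, which is why the bonus token (100)
        # lands at position 2 in the main test.
        assert bool(accepted.all())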