diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
new file mode 100644
index 0000000..ff7d060
--- /dev/null
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -0,0 +1,299 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/lora/test_layers.py
+
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+
+from tests.ut.base import TestBase
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    get_masked_input_and_mask, vocab_parallel_embedding_forward)
+
+VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
+
+
+class TestGetMaskedInputAndMask(TestBase):
+
+    def setUp(self):
+        self.input_ = torch.arange(12)
+
+    def test_get_masked_input_and_mask(self):
+        # tp 1 no padding
+        input_modified, _ = get_masked_input_and_mask(
+            self.input_,
+            org_vocab_start_index=0,
+            org_vocab_end_index=8,
+            added_vocab_start_index=8,
+            added_vocab_end_index=12,
+            num_org_vocab_padding=0)
+        assert torch.equal(self.input_, input_modified)
+
+        # tp 2 no padding
+        input_rank_0, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=0,
+                                                    org_vocab_end_index=4,
+                                                    added_vocab_start_index=8,
+                                                    added_vocab_end_index=10,
+                                                    num_org_vocab_padding=0)
+
+        input_rank_1, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=4,
+                                                    org_vocab_end_index=8,
+                                                    added_vocab_start_index=10,
+                                                    added_vocab_end_index=12,
+                                                    num_org_vocab_padding=0)
+
+        assert torch.equal(input_rank_0,
+                           torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
+        assert torch.equal(input_rank_1,
+                           torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))
+
+        # tp 4 no padding
+        input_rank_0, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=0,
+                                                    org_vocab_end_index=2,
+                                                    added_vocab_start_index=8,
+                                                    added_vocab_end_index=9,
+                                                    num_org_vocab_padding=0)
+
+        input_rank_1, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=2,
+                                                    org_vocab_end_index=4,
+                                                    added_vocab_start_index=9,
+                                                    added_vocab_end_index=10,
+                                                    num_org_vocab_padding=0)
+
+        input_rank_2, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=4,
+                                                    org_vocab_end_index=6,
+                                                    added_vocab_start_index=10,
+                                                    added_vocab_end_index=11,
+                                                    num_org_vocab_padding=0)
+
+        input_rank_3, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=6,
+                                                    org_vocab_end_index=8,
+                                                    added_vocab_start_index=11,
+                                                    added_vocab_end_index=12,
+                                                    num_org_vocab_padding=0)
+        assert torch.equal(input_rank_0,
+                           torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
+        assert torch.equal(input_rank_1,
+                           torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
+        assert torch.equal(input_rank_2,
+                           torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
+        assert torch.equal(input_rank_3,
+                           torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))
+
+        # tp 1 with padding
+        input_modified, _ = get_masked_input_and_mask(
+            self.input_,
+            org_vocab_start_index=0,
+            org_vocab_end_index=8,
+            added_vocab_start_index=8,
+            added_vocab_end_index=12,
+            num_org_vocab_padding=2)
+        assert torch.equal(
+            input_modified,
+            torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))
+
+        # tp 2 with padding
+        input_rank_0, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=0,
+                                                    org_vocab_end_index=4,
+                                                    added_vocab_start_index=8,
+                                                    added_vocab_end_index=10,
+                                                    num_org_vocab_padding=2)
+
+        input_rank_1, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=4,
+                                                    org_vocab_end_index=8,
+                                                    added_vocab_start_index=10,
+                                                    added_vocab_end_index=12,
+                                                    num_org_vocab_padding=2)
+        assert torch.equal(input_rank_0,
+                           torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
+        assert torch.equal(input_rank_1,
+                           torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))
+
+        # tp 4 with padding
+        input_rank_0, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=0,
+                                                    org_vocab_end_index=2,
+                                                    added_vocab_start_index=8,
+                                                    added_vocab_end_index=9,
+                                                    num_org_vocab_padding=2)
+
+        input_rank_1, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=2,
+                                                    org_vocab_end_index=4,
+                                                    added_vocab_start_index=9,
+                                                    added_vocab_end_index=10,
+                                                    num_org_vocab_padding=2)
+
+        input_rank_2, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=4,
+                                                    org_vocab_end_index=6,
+                                                    added_vocab_start_index=10,
+                                                    added_vocab_end_index=11,
+                                                    num_org_vocab_padding=2)
+
+        input_rank_3, _ = get_masked_input_and_mask(self.input_,
+                                                    org_vocab_start_index=6,
+                                                    org_vocab_end_index=8,
+                                                    added_vocab_start_index=11,
+                                                    added_vocab_end_index=12,
+                                                    num_org_vocab_padding=2)
+        assert torch.equal(input_rank_0,
+                           torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
+        assert torch.equal(input_rank_1,
+                           torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
+        assert torch.equal(input_rank_2,
+                           torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
+        assert torch.equal(input_rank_3,
+                           torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
+
+
+class TestVocabParallelEmbedding(TestBase):
+
+    def setUp(self):
+        # Create a mock VocabParallelEmbedding instance
+        self.mock_embedding = MagicMock(spec=VocabParallelEmbedding)
+        self.mock_embedding.tp_size = 2  # Test with tensor parallelism
+        self.mock_embedding.shard_indices = MagicMock()
+        self.mock_embedding.shard_indices.org_vocab_start_index = 10
+        self.mock_embedding.shard_indices.org_vocab_end_index = 20
+        self.mock_embedding.shard_indices.num_org_vocab_padding = 5
+        self.mock_embedding.shard_indices.added_vocab_start_index = 30
+        self.mock_embedding.shard_indices.added_vocab_end_index = 40
+        self.mock_embedding.quant_method = MagicMock()
+
+        # Set consistent embedding dimension for all tests
+        self.embedding_dim = 10
+        # Mock embedding returns tensor with shape (input_length, embedding_dim)
+        self.mock_embedding.quant_method.embedding = MagicMock(
+            side_effect=lambda _, x: torch.randn(x.shape[0], self.embedding_dim))
+
+    def test_get_masked_input_and_mask(self):
+        """Test the mask and offset calculation helper function."""
+        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes all cases
+
+        masked_input, mask = get_masked_input_and_mask(
+            input_,
+            org_vocab_start_index=10,
+            org_vocab_end_index=20,
+            num_org_vocab_padding=5,
+            added_vocab_start_index=30,
+            added_vocab_end_index=40)
+
+        # The mask should be True for INVALID tokens (ones we want to mask out)
+        expected_mask = torch.tensor([True, False, True, False, True])
+        self.assertTrue(
+            torch.equal(mask, expected_mask),
+            f"Mask mismatch. Expected {expected_mask}, got {mask}")
+
+        # Check masked input values
+        expected_masked = torch.tensor([0, 5, 0, 20, 0])
+        self.assertTrue(
+            torch.equal(masked_input, expected_masked),
+            f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
+        )
+
+    def test_forward_with_tp_size_1(self):
+        """Test forward pass without tensor parallelism."""
+        # Create a fresh mock embedding with tp_size=1
+        mock_embedding = MagicMock(spec=VocabParallelEmbedding)
+        mock_embedding.tp_size = 1
+        mock_embedding.quant_method = MagicMock()
+        mock_embedding.quant_method.embedding = MagicMock(
+            return_value=torch.randn(3, self.embedding_dim))
+
+        input_ = torch.tensor([1, 2, 3])
+
+        with patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                side_effect=lambda x: x) as mock_reduce_tp1:
+            output = vocab_parallel_embedding_forward(mock_embedding, input_)
+
+        # Should just pass through without masking
+        mock_embedding.quant_method.embedding.assert_called_once_with(
+            mock_embedding, input_.long())
+        self.assertEqual(output.shape, (3, self.embedding_dim))
+
+        # Verify all_reduce was called once
+        mock_reduce_tp1.assert_called_once()
+
+    def test_forward_with_tp(self):
+        """Test forward pass with tensor parallelism."""
+        input_ = torch.tensor([15, 35])  # one org vocab, one added vocab
+        with patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                side_effect=lambda x: x) as mock_reduce_tp:
+            output = vocab_parallel_embedding_forward(self.mock_embedding,
+                                                      input_)
+
+        # Check that masking was applied correctly
+        self.mock_embedding.quant_method.embedding.assert_called_once()
+        called_input = self.mock_embedding.quant_method.embedding.call_args[0][1]
+        expected_input = torch.tensor([5, 20])  # after offset calculation
+        self.assertTrue(torch.all(called_input == expected_input))
+
+        # Check that all reduce was called
+        # self.dist_mock.tensor_model_parallel_all_reduce.assert_called_once()
+        mock_reduce_tp.assert_called_once()
+        self.assertEqual(output.shape, (2, self.embedding_dim))
+
+    def test_forward_with_invalid_vocab(self):
+        """Test that invalid vocab indices are properly masked out."""
+        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases
+
+        # Create predictable mock output
+        mock_output = torch.randn(5, self.embedding_dim)
+        self.mock_embedding.quant_method.embedding = MagicMock(
+            return_value=mock_output.clone())
+        with patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                side_effect=lambda x: x):
+            output = vocab_parallel_embedding_forward(self.mock_embedding,
+                                                      input_)
+
+        # Check that invalid positions (0, 2, 4) were zeroed out
+        self.assertTrue(torch.all(output[0] == 0))
+        self.assertTrue(torch.all(output[2] == 0))
+        self.assertTrue(torch.all(output[4] == 0))
+        self.assertTrue(torch.all(output[1] == mock_output[1]))
+        self.assertTrue(torch.all(output[3] == mock_output[3]))
+        self.assertEqual(output.shape, (5, self.embedding_dim))
+
+    def test_output_shape(self):
+        """Test that output shape is correct."""
+        test_cases = [
+            (torch.tensor([15]), (1, self.embedding_dim)),
+            (torch.tensor([15, 35]), (2, self.embedding_dim)),
+            (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
+        ]
+
+        for input_, expected_shape in test_cases:
+            with self.subTest(input=input_):
+                with patch(
+                        "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                        side_effect=lambda x: x):
+                    output = vocab_parallel_embedding_forward(
+                        self.mock_embedding, input_)
+                    self.assertEqual(output.shape, expected_shape)