From c578f817ca4c17a076ac7fa93de77db11f008fae Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Thu, 28 Aug 2025 08:57:34 +0800 Subject: [PATCH] [CustomOp] Register VocabParallelEmbedding instead of overwrite forward (#2515) ### What this PR does / why we need it? Register VocabParallelEmbedding instead of overwrite forward ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? CI passed with new added/existing test. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/644d57d53191b94d9e50a4765891c498790d924b --------- Signed-off-by: Icey <1790571317@qq.com> --- tests/ut/ops/test_vocab_parallel_embedding.py | 259 +++++------------- tests/ut/test_utils.py | 4 +- vllm_ascend/ops/vocab_parallel_embedding.py | 93 ++++--- vllm_ascend/quantization/quantizer.py | 2 +- vllm_ascend/utils.py | 5 + 5 files changed, 122 insertions(+), 241 deletions(-) diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py index ff7d060..13ede67 100644 --- a/tests/ut/ops/test_vocab_parallel_embedding.py +++ b/tests/ut/ops/test_vocab_parallel_embedding.py @@ -13,187 +13,61 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm/tests/lora/test_layers.py +import unittest from unittest.mock import MagicMock, patch import torch -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding -from tests.ut.base import TestBase -from vllm_ascend.ops.vocab_parallel_embedding import ( - get_masked_input_and_mask, vocab_parallel_embedding_forward) +from vllm_ascend.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128 -class TestGetMaskedInputAndMask(TestBase): +class TestCustomVocabParallelEmbedding(unittest.TestCase): def setUp(self): - self.input_ = torch.arange(12) - - def test_get_masked_input_and_mask(self): - # tp 1 no padding - input_modified, _ = get_masked_input_and_mask( - self.input_, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(self.input_, input_modified) - - # tp 2 no padding - input_rank_0, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=0) - - input_rank_1, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=4, - org_vocab_end_index=8, - added_vocab_start_index=10, - added_vocab_end_index=12, - num_org_vocab_padding=0) - - assert torch.equal(input_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) - assert torch.equal(input_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) - - # tp 4 no padding - input_rank_0, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=0) - - input_rank_1, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=0) - - input_rank_2, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=4, - org_vocab_end_index=6, - added_vocab_start_index=10, - added_vocab_end_index=11, - num_org_vocab_padding=0) - - input_rank_3, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=6, - org_vocab_end_index=8, - 
added_vocab_start_index=11, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(input_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) - assert torch.equal(input_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) - assert torch.equal(input_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) - assert torch.equal(input_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) - - # tp 1 with padding - input_modified, _ = get_masked_input_and_mask( - self.input_, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal( - input_modified, - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) - - # tp 2 with padding - input_rank_0, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=2) - - input_rank_1, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=4, - org_vocab_end_index=8, - added_vocab_start_index=10, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(input_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) - assert torch.equal(input_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) - - # tp 4 with padding - input_rank_0, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=2) - - input_rank_1, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=2) - - input_rank_2, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=4, - org_vocab_end_index=6, - added_vocab_start_index=10, - added_vocab_end_index=11, - num_org_vocab_padding=2) - - input_rank_3, _ = get_masked_input_and_mask(self.input_, - org_vocab_start_index=6, - org_vocab_end_index=8, - added_vocab_start_index=11, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(input_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) - assert torch.equal(input_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) - assert torch.equal(input_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) - assert torch.equal(input_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) - - -class TestVocabParallelEmbedding(TestBase): - - def setUp(self): - # Create a mock VocabParallelEmbedding instance - self.mock_embedding = MagicMock(spec=VocabParallelEmbedding) - self.mock_embedding.tp_size = 2 # Test with tensor parallelism - self.mock_embedding.shard_indices = MagicMock() - self.mock_embedding.shard_indices.org_vocab_start_index = 10 - self.mock_embedding.shard_indices.org_vocab_end_index = 20 - self.mock_embedding.shard_indices.num_org_vocab_padding = 5 - self.mock_embedding.shard_indices.added_vocab_start_index = 30 - self.mock_embedding.shard_indices.added_vocab_end_index = 40 - self.mock_embedding.quant_method = MagicMock() - - # Set consistent embedding dimension for all tests + self.num_embeddings = 50 self.embedding_dim = 10 - # Mock embedding returns tensor with shape (input_length, embedding_dim) - self.mock_embedding.quant_method.embedding = MagicMock( - side_effect=lambda _, x: torch.randn(x.shape[0], self.embedding_dim - )) + self.org_num_embeddings = 40 + self.padding_size = 8 + 
+ def _create_layer(self): + # Patch methods and dependencies for VocabParallelEmbedding + with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \ + patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \ + patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \ + patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y): + + # Create an instance of VocabParallelEmbedding + layer = AscendVocabParallelEmbedding( + num_embeddings=self.num_embeddings, + embedding_dim=self.embedding_dim, + org_num_embeddings=self.org_num_embeddings, + padding_size=self.padding_size, + quant_config=None, # Mock quantization config + prefix="") + + layer.shard_indices = MagicMock() + layer.shard_indices.org_vocab_start_index = 10 + layer.shard_indices.org_vocab_end_index = 20 + layer.shard_indices.num_org_vocab_padding = 5 + layer.shard_indices.added_vocab_start_index = 30 + layer.shard_indices.added_vocab_end_index = 40 + + # Mock the quantization method + layer.quant_method.embedding = MagicMock( + side_effect=lambda _, x: torch.randn(x.shape[0], self. + embedding_dim)) + return layer def test_get_masked_input_and_mask(self): """Test the mask and offset calculation helper function.""" - input_ = torch.tensor([5, 15, 25, 35, 45]) # includes all cases + layer = self._create_layer() - masked_input, mask = get_masked_input_and_mask( + input_ = torch.tensor([5, 15, 25, 35, 45]) + + masked_input, mask = layer._get_masked_input_and_mask( input_, org_vocab_start_index=10, org_vocab_end_index=20, @@ -201,13 +75,11 @@ class TestVocabParallelEmbedding(TestBase): added_vocab_start_index=30, added_vocab_end_index=40) - # The mask should be True for INVALID tokens (ones we want to mask out) expected_mask = torch.tensor([True, False, True, False, True]) self.assertTrue( torch.equal(mask, expected_mask), f"Mask mismatch. 
Expected {expected_mask}, got {mask}") - # Check masked input values expected_masked = torch.tensor([0, 5, 0, 20, 0]) self.assertTrue( torch.equal(masked_input, expected_masked), @@ -217,62 +89,64 @@ class TestVocabParallelEmbedding(TestBase): def test_forward_with_tp_size_1(self): """Test forward pass without tensor parallelism.""" # Create a fresh mock embedding with tp_size=1 - mock_embedding = MagicMock(spec=VocabParallelEmbedding) - mock_embedding.tp_size = 1 - mock_embedding.quant_method = MagicMock() - mock_embedding.quant_method.embedding = MagicMock( - return_value=torch.randn(3, self.embedding_dim)) + layer = self._create_layer() + layer.tp_size = 1 + layer.quant_method.embedding = MagicMock( + return_value=torch.randn(3, layer.embedding_dim)) input_ = torch.tensor([1, 2, 3]) with patch( "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", side_effect=lambda x: x) as mock_reduce_tp1: - output = vocab_parallel_embedding_forward(mock_embedding, input_) + output = layer.forward(input_) # Should just pass through without masking - mock_embedding.quant_method.embedding.assert_called_once_with( - mock_embedding, input_.long()) - self.assertEqual(output.shape, (3, self.embedding_dim)) + layer.quant_method.embedding.assert_called_once_with( + layer, input_.long()) + self.assertEqual(output.shape, (3, layer.embedding_dim)) # Verify all_reduce was called once mock_reduce_tp1.assert_called_once() def test_forward_with_tp(self): - """Test forward pass with tensor parallelism.""" + layer = self._create_layer() + layer.tp_size = 2 + input_ = torch.tensor([15, 35]) # one org vocab, one added vocab + with patch( "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", side_effect=lambda x: x) as mock_reduce_tp: - output = vocab_parallel_embedding_forward(self.mock_embedding, - input_) + # Call the forward method + output = layer.forward(input_) # Check that masking was applied correctly - self.mock_embedding.quant_method.embedding.assert_called_once() - called_input = self.mock_embedding.quant_method.embedding.call_args[0][ - 1] + layer.quant_method.embedding.assert_called_once() + called_input = layer.quant_method.embedding.call_args[0][1] expected_input = torch.tensor([5, 20]) # after offset calculation self.assertTrue(torch.all(called_input == expected_input)) # Check that all reduce was called - # self.dist_mock.tensor_model_parallel_all_reduce.assert_called_once() mock_reduce_tp.assert_called_once() self.assertEqual(output.shape, (2, self.embedding_dim)) def test_forward_with_invalid_vocab(self): """Test that invalid vocab indices are properly masked out.""" + # Create a fresh embedding layer + layer = self._create_layer() input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases - # Create predictable mock output mock_output = torch.randn(5, self.embedding_dim) - self.mock_embedding.quant_method.embedding = MagicMock( + layer.quant_method.embedding = MagicMock( return_value=mock_output.clone()) + + # Patch tensor_model_parallel_all_reduce to mock its behavior with patch( "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", side_effect=lambda x: x): - output = vocab_parallel_embedding_forward(self.mock_embedding, - input_) - + # Call the forward method + output = layer.forward(input_) # Check that invalid positions (0, 2, 4) were zeroed out self.assertTrue(torch.all(output[0] == 0)) self.assertTrue(torch.all(output[2] == 0)) @@ -283,6 +157,9 @@ class TestVocabParallelEmbedding(TestBase): def 
test_output_shape(self): """Test that output shape is correct.""" + # Create a fresh embedding layer + layer = self._create_layer() + test_cases = [ (torch.tensor([15]), (1, self.embedding_dim)), (torch.tensor([15, 35]), (2, self.embedding_dim)), @@ -294,6 +171,6 @@ class TestVocabParallelEmbedding(TestBase): with patch( "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", side_effect=lambda x: x): - output = vocab_parallel_embedding_forward( - self.mock_embedding, input_) + # Call the forward method + output = layer.forward(input_) self.assertEqual(output.shape, expected_shape) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index b144016..396f457 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -289,13 +289,13 @@ class TestUtils(TestBase): # ascend custom op is not registered utils.register_ascend_customop() # should call register_oot three - self.assertEqual(mock_customop.register_oot.call_count, 9) + self.assertEqual(mock_customop.register_oot.call_count, 10) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) # ascend custom op is already registered utils.register_ascend_customop() # should not register_oot again, thus only called three in this ut - self.assertEqual(mock_customop.register_oot.call_count, 9) + self.assertEqual(mock_customop.register_oot.call_count, 10) class TestProfileExecuteDuration(TestBase): diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index c35d2f4..05b08a4 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -23,52 +23,51 @@ from vllm.model_executor.layers.vocab_parallel_embedding import \ VocabParallelEmbedding -def get_masked_input_and_mask( - input_: torch.Tensor, org_vocab_start_index: int, - org_vocab_end_index: int, num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: - # torch.compile will fuse all of the pointwise ops below - # into a single kernel, making it very fast - org_vocab_mask = (input_ >= org_vocab_start_index) & ( - input_ < org_vocab_end_index) - # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index. - if added_vocab_start_index == added_vocab_end_index: - valid_offset = (org_vocab_start_index * org_vocab_mask) - vocab_mask = org_vocab_mask - else: - added_vocab_mask = (input_ >= added_vocab_start_index) & ( - input_ < added_vocab_end_index) - added_offset = added_vocab_start_index - ( - org_vocab_end_index - - org_vocab_start_index) - num_org_vocab_padding - valid_offset = (org_vocab_start_index * - org_vocab_mask) + (added_offset * added_vocab_mask) - vocab_mask = org_vocab_mask | added_vocab_mask - # Adapt end. - input_ = vocab_mask * (input_ - valid_offset) - return input_, ~vocab_mask +class AscendVocabParallelEmbedding(VocabParallelEmbedding): + def _get_masked_input_and_mask( + self, input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + # torch.compile will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & ( + input_ < org_vocab_end_index) + # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index. 
+ if added_vocab_start_index == added_vocab_end_index: + valid_offset = (org_vocab_start_index * org_vocab_mask) + vocab_mask = org_vocab_mask + else: + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - + org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + # Adapt end. + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask -def vocab_parallel_embedding_forward(self, input_): - if self.tp_size > 1: - # Build the mask. - masked_input, input_mask = get_masked_input_and_mask( - input_, self.shard_indices.org_vocab_start_index, - self.shard_indices.org_vocab_end_index, - self.shard_indices.num_org_vocab_padding, - self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) - else: - masked_input = input_ - # Get the embeddings. - output_parallel = self.quant_method.embedding(self, masked_input.long()) - # Mask the output embedding. - if self.tp_size > 1: - output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) - # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) - return output - - -VocabParallelEmbedding.forward = vocab_parallel_embedding_forward + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = self._get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, + masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 988f8bd..0e15ed2 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -78,7 +78,7 @@ class VLLMAscendQuantizer: "vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot", [wrapper_rmsnorm_forward_oot]) VLLMAscendQuantizer.apply_patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding", + "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding", "__init__", [wrapper_vocab_parallel_embedding_init]) break VLLMAscendQuantizer.patched = True diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 1273805..a99a491 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -512,6 +512,11 @@ def register_ascend_customop(): from vllm_ascend.ops.common_fused_moe import AscendFusedMoE CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE") + from vllm_ascend.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding + CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding, + name="VocabParallelEmbedding") + # NOTE: Keep this at last to ensure all custom actions are registered _ASCEND_CUSTOMOP_IS_REIGISTERED = True
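
A minimal standalone sketch of the index remapping that `_get_masked_input_and_mask` performs, using the same shard layout as `test_get_masked_input_and_mask` above (original-vocab shard `[10, 20)`, 5 padding slots, added-vocab shard `[30, 40)`). It assumes only that `torch` is installed and mirrors the pointwise ops from the patch instead of importing vllm-ascend, so the helper name here is illustrative rather than the actual class method:

```python
import torch


def masked_input_and_mask(input_, org_start, org_end, org_padding,
                          added_start, added_end):
    # Tokens that fall inside this rank's original-vocab shard.
    org_mask = (input_ >= org_start) & (input_ < org_end)
    # Tokens that fall inside this rank's added-vocab shard.
    added_mask = (input_ >= added_start) & (input_ < added_end)
    # Added-vocab rows sit right after the padded original shard in the weight.
    added_offset = added_start - (org_end - org_start) - org_padding
    valid_offset = org_start * org_mask + added_offset * added_mask
    vocab_mask = org_mask | added_mask
    # Out-of-shard tokens are clamped to row 0; their embeddings are zeroed later.
    return vocab_mask * (input_ - valid_offset), ~vocab_mask


tokens = torch.tensor([5, 15, 25, 35, 45])
rows, invalid = masked_input_and_mask(tokens, 10, 20, 5, 30, 40)
assert torch.equal(rows, torch.tensor([0, 5, 0, 20, 0]))
assert torch.equal(invalid, torch.tensor([True, False, True, False, True]))
```

In `AscendVocabParallelEmbedding.forward`, the rows flagged by this mask are zeroed with `masked_fill_` before the tensor-parallel all-reduce, so each rank contributes only its own shard's embeddings; registration happens via `CustomOp.register_oot(..., name="VocabParallelEmbedding")` in `register_ascend_customop`, replacing the previous monkey-patch of `VocabParallelEmbedding.forward`.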