[CustomOp] Register VocabParallelEmbedding instead of overwrite forward (#2515)

### What this PR does / why we need it?
Register VocabParallelEmbedding instead of overwrite forward

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.10.1.1
- vLLM main:
644d57d531

---------

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-08-28 08:57:34 +08:00
committed by GitHub
parent 516e14ae6a
commit c578f817ca
5 changed files with 122 additions and 241 deletions

View File

@@ -13,187 +13,61 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/lora/test_layers.py # Adapted from vllm/tests/lora/test_layers.py
import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import torch import torch
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from tests.ut.base import TestBase from vllm_ascend.ops.vocab_parallel_embedding import \
from vllm_ascend.ops.vocab_parallel_embedding import ( AscendVocabParallelEmbedding
get_masked_input_and_mask, vocab_parallel_embedding_forward)
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128 VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
class TestGetMaskedInputAndMask(TestBase): class TestCustomVocabParallelEmbedding(unittest.TestCase):
def setUp(self): def setUp(self):
self.input_ = torch.arange(12) self.num_embeddings = 50
def test_get_masked_input_and_mask(self):
# tp 1 no padding
input_modified, _ = get_masked_input_and_mask(
self.input_,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(self.input_, input_modified)
# tp 2 no padding
input_rank_0, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=0)
input_rank_1, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(input_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
assert torch.equal(input_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))
# tp 4 no padding
input_rank_0, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=0)
input_rank_1, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=0)
input_rank_2, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=0)
input_rank_3, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(input_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
assert torch.equal(input_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
assert torch.equal(input_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
assert torch.equal(input_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))
# tp 1 with padding
input_modified, _ = get_masked_input_and_mask(
self.input_,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(
input_modified,
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))
# tp 2 with padding
input_rank_0, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=2)
input_rank_1, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(input_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
assert torch.equal(input_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))
# tp 4 with padding
input_rank_0, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=2)
input_rank_1, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=2)
input_rank_2, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=2)
input_rank_3, _ = get_masked_input_and_mask(self.input_,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(input_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
assert torch.equal(input_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
assert torch.equal(input_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
assert torch.equal(input_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
class TestVocabParallelEmbedding(TestBase):
def setUp(self):
# Create a mock VocabParallelEmbedding instance
self.mock_embedding = MagicMock(spec=VocabParallelEmbedding)
self.mock_embedding.tp_size = 2 # Test with tensor parallelism
self.mock_embedding.shard_indices = MagicMock()
self.mock_embedding.shard_indices.org_vocab_start_index = 10
self.mock_embedding.shard_indices.org_vocab_end_index = 20
self.mock_embedding.shard_indices.num_org_vocab_padding = 5
self.mock_embedding.shard_indices.added_vocab_start_index = 30
self.mock_embedding.shard_indices.added_vocab_end_index = 40
self.mock_embedding.quant_method = MagicMock()
# Set consistent embedding dimension for all tests
self.embedding_dim = 10 self.embedding_dim = 10
# Mock embedding returns tensor with shape (input_length, embedding_dim) self.org_num_embeddings = 40
self.mock_embedding.quant_method.embedding = MagicMock( self.padding_size = 8
side_effect=lambda _, x: torch.randn(x.shape[0], self.embedding_dim
)) def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
# Create an instance of VocabParallelEmbedding
layer = AscendVocabParallelEmbedding(
num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
org_num_embeddings=self.org_num_embeddings,
padding_size=self.padding_size,
quant_config=None, # Mock quantization config
prefix="")
layer.shard_indices = MagicMock()
layer.shard_indices.org_vocab_start_index = 10
layer.shard_indices.org_vocab_end_index = 20
layer.shard_indices.num_org_vocab_padding = 5
layer.shard_indices.added_vocab_start_index = 30
layer.shard_indices.added_vocab_end_index = 40
# Mock the quantization method
layer.quant_method.embedding = MagicMock(
side_effect=lambda _, x: torch.randn(x.shape[0], self.
embedding_dim))
return layer
def test_get_masked_input_and_mask(self): def test_get_masked_input_and_mask(self):
"""Test the mask and offset calculation helper function.""" """Test the mask and offset calculation helper function."""
input_ = torch.tensor([5, 15, 25, 35, 45]) # includes all cases layer = self._create_layer()
masked_input, mask = get_masked_input_and_mask( input_ = torch.tensor([5, 15, 25, 35, 45])
masked_input, mask = layer._get_masked_input_and_mask(
input_, input_,
org_vocab_start_index=10, org_vocab_start_index=10,
org_vocab_end_index=20, org_vocab_end_index=20,
@@ -201,13 +75,11 @@ class TestVocabParallelEmbedding(TestBase):
added_vocab_start_index=30, added_vocab_start_index=30,
added_vocab_end_index=40) added_vocab_end_index=40)
# The mask should be True for INVALID tokens (ones we want to mask out)
expected_mask = torch.tensor([True, False, True, False, True]) expected_mask = torch.tensor([True, False, True, False, True])
self.assertTrue( self.assertTrue(
torch.equal(mask, expected_mask), torch.equal(mask, expected_mask),
f"Mask mismatch. Expected {expected_mask}, got {mask}") f"Mask mismatch. Expected {expected_mask}, got {mask}")
# Check masked input values
expected_masked = torch.tensor([0, 5, 0, 20, 0]) expected_masked = torch.tensor([0, 5, 0, 20, 0])
self.assertTrue( self.assertTrue(
torch.equal(masked_input, expected_masked), torch.equal(masked_input, expected_masked),
@@ -217,62 +89,64 @@ class TestVocabParallelEmbedding(TestBase):
def test_forward_with_tp_size_1(self): def test_forward_with_tp_size_1(self):
"""Test forward pass without tensor parallelism.""" """Test forward pass without tensor parallelism."""
# Create a fresh mock embedding with tp_size=1 # Create a fresh mock embedding with tp_size=1
mock_embedding = MagicMock(spec=VocabParallelEmbedding) layer = self._create_layer()
mock_embedding.tp_size = 1 layer.tp_size = 1
mock_embedding.quant_method = MagicMock() layer.quant_method.embedding = MagicMock(
mock_embedding.quant_method.embedding = MagicMock( return_value=torch.randn(3, layer.embedding_dim))
return_value=torch.randn(3, self.embedding_dim))
input_ = torch.tensor([1, 2, 3]) input_ = torch.tensor([1, 2, 3])
with patch( with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp1: side_effect=lambda x: x) as mock_reduce_tp1:
output = vocab_parallel_embedding_forward(mock_embedding, input_) output = layer.forward(input_)
# Should just pass through without masking # Should just pass through without masking
mock_embedding.quant_method.embedding.assert_called_once_with( layer.quant_method.embedding.assert_called_once_with(
mock_embedding, input_.long()) layer, input_.long())
self.assertEqual(output.shape, (3, self.embedding_dim)) self.assertEqual(output.shape, (3, layer.embedding_dim))
# Verify all_reduce was called once # Verify all_reduce was called once
mock_reduce_tp1.assert_called_once() mock_reduce_tp1.assert_called_once()
def test_forward_with_tp(self): def test_forward_with_tp(self):
"""Test forward pass with tensor parallelism.""" layer = self._create_layer()
layer.tp_size = 2
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
with patch( with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp: side_effect=lambda x: x) as mock_reduce_tp:
output = vocab_parallel_embedding_forward(self.mock_embedding, # Call the forward method
input_) output = layer.forward(input_)
# Check that masking was applied correctly # Check that masking was applied correctly
self.mock_embedding.quant_method.embedding.assert_called_once() layer.quant_method.embedding.assert_called_once()
called_input = self.mock_embedding.quant_method.embedding.call_args[0][ called_input = layer.quant_method.embedding.call_args[0][1]
1]
expected_input = torch.tensor([5, 20]) # after offset calculation expected_input = torch.tensor([5, 20]) # after offset calculation
self.assertTrue(torch.all(called_input == expected_input)) self.assertTrue(torch.all(called_input == expected_input))
# Check that all reduce was called # Check that all reduce was called
# self.dist_mock.tensor_model_parallel_all_reduce.assert_called_once()
mock_reduce_tp.assert_called_once() mock_reduce_tp.assert_called_once()
self.assertEqual(output.shape, (2, self.embedding_dim)) self.assertEqual(output.shape, (2, self.embedding_dim))
def test_forward_with_invalid_vocab(self): def test_forward_with_invalid_vocab(self):
"""Test that invalid vocab indices are properly masked out.""" """Test that invalid vocab indices are properly masked out."""
# Create a fresh embedding layer
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases
# Create predictable mock output # Create predictable mock output
mock_output = torch.randn(5, self.embedding_dim) mock_output = torch.randn(5, self.embedding_dim)
self.mock_embedding.quant_method.embedding = MagicMock( layer.quant_method.embedding = MagicMock(
return_value=mock_output.clone()) return_value=mock_output.clone())
# Patch tensor_model_parallel_all_reduce to mock its behavior
with patch( with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x): side_effect=lambda x: x):
output = vocab_parallel_embedding_forward(self.mock_embedding, # Call the forward method
input_) output = layer.forward(input_)
# Check that invalid positions (0, 2, 4) were zeroed out # Check that invalid positions (0, 2, 4) were zeroed out
self.assertTrue(torch.all(output[0] == 0)) self.assertTrue(torch.all(output[0] == 0))
self.assertTrue(torch.all(output[2] == 0)) self.assertTrue(torch.all(output[2] == 0))
@@ -283,6 +157,9 @@ class TestVocabParallelEmbedding(TestBase):
def test_output_shape(self): def test_output_shape(self):
"""Test that output shape is correct.""" """Test that output shape is correct."""
# Create a fresh embedding layer
layer = self._create_layer()
test_cases = [ test_cases = [
(torch.tensor([15]), (1, self.embedding_dim)), (torch.tensor([15]), (1, self.embedding_dim)),
(torch.tensor([15, 35]), (2, self.embedding_dim)), (torch.tensor([15, 35]), (2, self.embedding_dim)),
@@ -294,6 +171,6 @@ class TestVocabParallelEmbedding(TestBase):
with patch( with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x): side_effect=lambda x: x):
output = vocab_parallel_embedding_forward( # Call the forward method
self.mock_embedding, input_) output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)

View File

@@ -289,13 +289,13 @@ class TestUtils(TestBase):
# ascend custom op is not registered # ascend custom op is not registered
utils.register_ascend_customop() utils.register_ascend_customop()
# should call register_oot three # should call register_oot three
self.assertEqual(mock_customop.register_oot.call_count, 9) self.assertEqual(mock_customop.register_oot.call_count, 10)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
# ascend custom op is already registered # ascend custom op is already registered
utils.register_ascend_customop() utils.register_ascend_customop()
# should not register_oot again, thus only called three in this ut # should not register_oot again, thus only called three in this ut
self.assertEqual(mock_customop.register_oot.call_count, 9) self.assertEqual(mock_customop.register_oot.call_count, 10)
class TestProfileExecuteDuration(TestBase): class TestProfileExecuteDuration(TestBase):

View File

@@ -23,52 +23,51 @@ from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding VocabParallelEmbedding
def get_masked_input_and_mask( class AscendVocabParallelEmbedding(VocabParallelEmbedding):
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
# Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
if added_vocab_start_index == added_vocab_end_index:
valid_offset = (org_vocab_start_index * org_vocab_mask)
vocab_mask = org_vocab_mask
else:
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index -
org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
# Adapt end.
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
def _get_masked_input_and_mask(
self, input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
# Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
if added_vocab_start_index == added_vocab_end_index:
valid_offset = (org_vocab_start_index * org_vocab_mask)
vocab_mask = org_vocab_mask
else:
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index -
org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
# Adapt end.
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
def vocab_parallel_embedding_forward(self, input_): def forward(self, input_):
if self.tp_size > 1: if self.tp_size > 1:
# Build the mask. # Build the mask.
masked_input, input_mask = get_masked_input_and_mask( masked_input, input_mask = self._get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index, input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index, self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding, self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index, self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index) self.shard_indices.added_vocab_end_index)
else: else:
masked_input = input_ masked_input = input_
# Get the embeddings. # Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long()) output_parallel = self.quant_method.embedding(self,
# Mask the output embedding. masked_input.long())
if self.tp_size > 1: # Mask the output embedding.
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) if self.tp_size > 1:
# Reduce across all the model parallel GPUs. output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
output = tensor_model_parallel_all_reduce(output_parallel) # Reduce across all the model parallel GPUs.
return output output = tensor_model_parallel_all_reduce(output_parallel)
return output
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward

View File

@@ -78,7 +78,7 @@ class VLLMAscendQuantizer:
"vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot", "vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot",
[wrapper_rmsnorm_forward_oot]) [wrapper_rmsnorm_forward_oot])
VLLMAscendQuantizer.apply_patch( VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding", "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding",
"__init__", [wrapper_vocab_parallel_embedding_init]) "__init__", [wrapper_vocab_parallel_embedding_init])
break break
VLLMAscendQuantizer.patched = True VLLMAscendQuantizer.patched = True

View File

@@ -512,6 +512,11 @@ def register_ascend_customop():
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE") CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
name="VocabParallelEmbedding")
# NOTE: Keep this at last to ensure all custom actions are registered # NOTE: Keep this at last to ensure all custom actions are registered
_ASCEND_CUSTOMOP_IS_REIGISTERED = True _ASCEND_CUSTOMOP_IS_REIGISTERED = True