Files
xc-llm-ascend/tests/ut/ops/test_vocab_parallel_embedding.py
lidenghui1110 5a7181569c [feat]: oproj tensor parallelism in pure DP and graph-mode scenarios. (#2167)
### What this PR does / why we need it?
This PR introduces Oproj matrix tensor model parallel to achieve
decreasing of memory consumption. It only support graph mode in pure DP
scenario.

In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
oproj_tensor_parallel_size = 8, we have 1 ms TPOT increasing, saved 5.8
GB NPU memory per RANK. We got best performance when
oproj_tensor_parallel_size=4 without TPOT increasing.

performance data:
<img width="1442" height="442" alt="image"
src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d"
/>

### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |

example

`--additional_config={"oproj_tensor_parallel_size": 8}`

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
eddaafc1c7

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
2025-09-07 10:31:32 +08:00

241 lines
10 KiB
Python

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/lora/test_layers.py
import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
class TestCustomVocabParallelEmbedding(unittest.TestCase):
def setUp(self):
self.num_embeddings = 50
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
mock_group = MagicMock()
mock_group.world_size = 2
mock_group.rank_in_group = 0
with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
# Create an instance of VocabParallelEmbedding
layer = AscendVocabParallelEmbedding(
num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
org_num_embeddings=self.org_num_embeddings,
padding_size=self.padding_size,
quant_config=None, # Mock quantization config
prefix="")
layer.shard_indices = MagicMock()
layer.shard_indices.org_vocab_start_index = 10
layer.shard_indices.org_vocab_end_index = 20
layer.shard_indices.num_org_vocab_padding = 5
layer.shard_indices.added_vocab_start_index = 30
layer.shard_indices.added_vocab_end_index = 40
# Mock the quantization method
layer.quant_method.embedding = MagicMock(
side_effect=lambda _, x: torch.randn(x.shape[0], self.
embedding_dim))
return layer
def test_get_masked_input_and_mask(self):
"""Test the mask and offset calculation helper function."""
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45])
masked_input, mask = layer._get_masked_input_and_mask(
input_,
org_vocab_start_index=10,
org_vocab_end_index=20,
num_org_vocab_padding=5,
added_vocab_start_index=30,
added_vocab_end_index=40)
expected_mask = torch.tensor([True, False, True, False, True])
self.assertTrue(
torch.equal(mask, expected_mask),
f"Mask mismatch. Expected {expected_mask}, got {mask}")
expected_masked = torch.tensor([0, 5, 0, 20, 0])
self.assertTrue(
torch.equal(masked_input, expected_masked),
f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
)
def test_forward_with_tp_size_1(self):
"""Test forward pass without tensor parallelism."""
# Create a fresh mock embedding with tp_size=1
layer = self._create_layer()
layer.tp_size = 1
layer.quant_method.embedding = MagicMock(
return_value=torch.randn(3, layer.embedding_dim))
input_ = torch.tensor([1, 2, 3])
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp1:
output = layer.forward(input_)
# Should just pass through without masking
layer.quant_method.embedding.assert_called_once_with(
layer, input_.long())
self.assertEqual(output.shape, (3, layer.embedding_dim))
# Verify all_reduce was called once
mock_reduce_tp1.assert_called_once()
def test_forward_with_tp(self):
layer = self._create_layer()
layer.tp_size = 2
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp:
# Call the forward method
output = layer.forward(input_)
# Check that masking was applied correctly
layer.quant_method.embedding.assert_called_once()
called_input = layer.quant_method.embedding.call_args[0][1]
expected_input = torch.tensor([5, 20]) # after offset calculation
self.assertTrue(torch.all(called_input == expected_input))
# Check that all reduce was called
mock_reduce_tp.assert_called_once()
self.assertEqual(output.shape, (2, self.embedding_dim))
def test_forward_with_invalid_vocab(self):
"""Test that invalid vocab indices are properly masked out."""
# Create a fresh embedding layer
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases
# Create predictable mock output
mock_output = torch.randn(5, self.embedding_dim)
layer.quant_method.embedding = MagicMock(
return_value=mock_output.clone())
# Patch tensor_model_parallel_all_reduce to mock its behavior
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
# Check that invalid positions (0, 2, 4) were zeroed out
self.assertTrue(torch.all(output[0] == 0))
self.assertTrue(torch.all(output[2] == 0))
self.assertTrue(torch.all(output[4] == 0))
self.assertTrue(torch.all(output[1] == mock_output[1]))
self.assertTrue(torch.all(output[3] == mock_output[3]))
self.assertEqual(output.shape, (5, self.embedding_dim))
def test_output_shape(self):
"""Test that output shape is correct."""
# Create a fresh embedding layer
layer = self._create_layer()
test_cases = [
(torch.tensor([15]), (1, self.embedding_dim)),
(torch.tensor([15, 35]), (2, self.embedding_dim)),
(torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
]
for input_, expected_shape in test_cases:
with self.subTest(input=input_):
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape)
class TestAscendLogitsProcessor(unittest.TestCase):
def setUp(self):
self.vocab_size = 50
self.num_embeddings = 50
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
self.mock_group = MagicMock()
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0
self.mock_ascend_config = MagicMock()
self.mock_quant_method = MagicMock()
self.mock_quant_method.apply = MagicMock(
return_value=torch.randn(1, self.vocab_size))
self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config",
return_value=self.mock_ascend_config),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
return_value=True),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
return_value=torch.randn(1, self.vocab_size)),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather",
return_value=torch.randn(1, self.vocab_size)),
patch(
"vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config",
return_value=MagicMock(max_num_batched_tokens=1000,
max_model_len=512,
enable_chunked_prefill=False))
]
for p in self.patches:
p.start()
def tearDown(self):
for p in self.patches:
p.stop()
def test_create_processor(self):
processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
self.assertEqual(processor.vocab_size, self.vocab_size)
def test_get_logits(self):
processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
prefix="lm_head")
lmhead.quant_method = self.mock_quant_method
lmhead.quant_method.apply = self.mock_quant_method.apply
hidden_state = torch.randn(1, self.org_num_embeddings)
processor._get_logits(hidden_state, lmhead)
self.mock_quant_method.apply.assert_called_once()