[Feat]: Add custom lmhead tensor model parallel (#2309)
### What this PR does / why we need it?
This PR introduces LM head tensor model parallelism to reduce memory consumption and improve TPOT. It supports both eager mode and graph mode.
On a DeepSeek R1 W8A8 PD-disaggregated decode instance using pure DP, with `lmhead_tensor_parallel_size = 8`, we observed a 1 ms TPOT improvement and saved 1.48 GB of NPU memory per rank.
Performance data:
<img width="1444" height="438" alt="image"
src="https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0"
/>
### Does this PR introduce _any_ user-facing change?
This PR introduces one new option in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :--- | :--- | :--- | :--- | :--- |
| `lmhead_tensor_parallel_size` | Splits the lm_head matrix along the column dimension (vocab_size) into `lmhead_tensor_parallel_size` pieces | No | int | Defaults to `None`; once a value is set, the feature is enabled. `vocab_size` must be divisible by this value. |

Example:
`--additional_config={"lmhead_tensor_parallel_size": 8}`
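For intuition, here is a minimal sketch of the vocab-dimension split this option enables. Illustrative only: `shard_lm_head` is not the actual vllm-ascend API, and the shapes are toy values.

```python
import torch

def shard_lm_head(weight: torch.Tensor, lmhead_tp_size: int, rank: int) -> torch.Tensor:
    """Slice the [vocab_size, hidden_size] lm_head weight along the vocab
    dimension; each lmhead-TP rank keeps one contiguous shard."""
    vocab_size = weight.shape[0]
    assert vocab_size % lmhead_tp_size == 0, \
        "vocab_size must be divisible by lmhead_tensor_parallel_size"
    shard = vocab_size // lmhead_tp_size
    return weight[rank * shard:(rank + 1) * shard]

# Each rank computes logits only for its vocab shard, so the lm_head weight
# (and the logits matmul) shrinks by lmhead_tp_size per rank; the full logits
# row is reassembled across the lmhead-TP group afterwards.
weight = torch.randn(16, 4)  # toy [vocab_size, hidden_size]
part = shard_lm_head(weight, lmhead_tp_size=8, rank=0)
assert part.shape == (2, 4)
```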
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: de533ab2a1
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
tests/ut/distributed/test_parallel_state.py (new file, +44)
@@ -0,0 +1,44 @@
from unittest.mock import MagicMock, patch

import pytest
from vllm.config import ParallelConfig

from vllm_ascend.distributed.parallel_state import (
    _LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
    get_mc2_group, init_ascend_model_parallel)


@pytest.fixture
def parallel_config():
    return ParallelConfig(data_parallel_size=2,
                          tensor_parallel_size=2,
                          pipeline_parallel_size=2)


@pytest.fixture
def mock_distributed():
    with patch('torch.distributed.is_initialized', return_value=True), \
            patch('torch.distributed.get_world_size', return_value=8), \
            patch('torch.distributed.get_backend', return_value='nccl'), \
            patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
        mock_group.return_value.local_rank = 0
        mock_group.return_value.device_group = MagicMock()
        yield


def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
            patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
            patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        mc2_group = get_mc2_group()
        assert mc2_group is not None
        lmheadtp_group = get_lmhead_tp_group()
        assert lmheadtp_group is not None

        destroy_ascend_model_parallel()
        assert _MC2 is None
        assert _LMTP is None
@@ -31,6 +31,11 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
        mocker_deepseek_v2_decode_layer = mocker.patch(
            "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
        mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        predictor = CustomDeepSeekMultiTokenPredictor(
            vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ class TestCustomDeepSeekMTP(PytestBase):
            return_value=None)
        mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                     return_value=None)
        mocker.patch(
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
        return mtp
@@ -177,4 +192,4 @@ class TestCustomDeepSeekMTP(PytestBase):
        output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                   previous_hidden_states, inputs_embeds,
                                   spec_step_idx)
        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
@@ -26,7 +26,7 @@ from vllm_ascend.models.deepseek_v2 import (
    CustomDeepseekV2MLP, CustomDeepseekV2MoE,
    CustomDeepseekV2RowParallelLinear,
    CustomDeepseekV2RowParallelLinearReplaceAllreduce,
-    CustomDeepseekV2SiluAndMul)
+    CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)


@pytest.fixture
@@ -266,3 +266,30 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
        kv_lora_rank=16,
        prefix="layers.1.self_attn")
    assert hasattr(attn, "q_proj")


def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
    # Create a simple config object
    class SimpleConfig:

        def __init__(self):
            self.vocab_size = 10000
            self.hidden_size = 128

    config = SimpleConfig()

    # Create the lmhead and logits_processor directly
    lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
    logits_processor = LogitsProcessor(config.vocab_size)

    # Create mock outputs
    mock_output = torch.randn(2, 4, config.hidden_size)
    mock_logits = torch.randn(2, 4, config.vocab_size)

    # Test the logits_processor directly
    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
        with patch.object(logits_processor,
                          "_gather_logits",
                          return_value=mock_logits):
            logits = logits_processor(lmhead, mock_output)
            assert logits.shape == (2, 4, config.vocab_size)
@@ -18,8 +18,8 @@ from unittest.mock import MagicMock, patch

import torch

-from vllm_ascend.ops.vocab_parallel_embedding import \
-    AscendVocabParallelEmbedding
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
@@ -34,7 +34,11 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
    def _create_layer(self):
        # Patch methods and dependencies for VocabParallelEmbedding
-        with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
+        mock_group = MagicMock()
+        mock_group.world_size = 2
+        mock_group.rank_in_group = 0
+        with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
+             patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
             patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
             patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
             patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
@@ -174,3 +178,55 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
        # Call the forward method
        output = layer.forward(input_)
        self.assertEqual(output.shape, expected_shape)


class TestAscendLogitsProcessor(unittest.TestCase):

    def setUp(self):
        self.vocab_size = 50
        self.num_embeddings = 50
        self.embedding_dim = 10
        self.org_num_embeddings = 40
        self.padding_size = 8

        self.mock_group = MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0
        self.mock_ascend_config = MagicMock()
        self.mock_quant_method = MagicMock()
        self.mock_quant_method.apply = MagicMock(
            return_value=torch.randn(1, self.vocab_size))
        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
                return_value=self.mock_group),
            patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
                  return_value=True),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
                return_value=torch.randn(1, self.vocab_size))
        ]

        for p in self.patches:
            p.start()

    def tearDown(self):
        for p in self.patches:
            p.stop()

    def test_create_processor(self):
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        self.assertEqual(processor.vocab_size, self.vocab_size)

    def test_get_logits(self):
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
                                      embedding_dim=self.embedding_dim,
                                      prefix="lm_head")
        lmhead.quant_method = self.mock_quant_method
        lmhead.quant_method.apply = self.mock_quant_method.apply
        hidden_state = torch.randn(1, self.org_num_embeddings)
        processor._get_logits(hidden_state, lmhead)
        self.mock_quant_method.apply.assert_called_once()
@@ -16,7 +16,7 @@
import os

from transformers import PretrainedConfig
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, ParallelConfig, VllmConfig

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import (_check_torchair_supported,
@@ -75,7 +75,7 @@ class TestAscendConfig(TestBase):
                "enabled": True
            },
            "expert_map_path": "test_expert_map_path",
-            "refresh": True
+            "refresh": True,
        }
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
@@ -304,3 +304,12 @@ class TestAscendConfig(TestBase):
                "refresh": True
            }
            init_ascend_config(test_vllm_config)

        with self.assertRaises(AssertionError):
            test_vllm_config.additional_config = {
                "lmhead_tensor_parallel_size": 2,
                "refresh": True
            }
            test_vllm_config.parallel_config = ParallelConfig(
                data_parallel_size=4, tensor_parallel_size=2)
            init_ascend_config(test_vllm_config)
@@ -289,13 +289,13 @@ class TestUtils(TestBase):
        # ascend custom op is not registered
        utils.register_ascend_customop()
        # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 10)
+        self.assertEqual(mock_customop.register_oot.call_count, 12)
        self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

        # ascend custom op is already registered
        utils.register_ascend_customop()
        # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 10)
+        self.assertEqual(mock_customop.register_oot.call_count, 12)


class TestProfileExecuteDuration(TestBase):
@@ -31,6 +31,11 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
        mocker_deepseek_v2_decode_layer = mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
        mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        predictor = TorchairDeepSeekMultiTokenPredictor(
            vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ class TestTorchairDeepSeekMTP(PytestBase):
            return_value=None)
        mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                     return_value=None)
        mocker.patch(
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
        return mtp
@@ -177,4 +192,4 @@ class TestTorchairDeepSeekMTP(PytestBase):
        output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                   previous_hidden_states, inputs_embeds,
                                   spec_step_idx)
        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))