diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index cf4d1dc..cdb0908 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -34,6 +34,7 @@ The following table lists the additional configuration options available in vLLM
 | `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
+| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of `lm_head`. Only supported in the pure data-parallel scenario (`tensor_parallel_size=1`). |

 The details of each config option are as follows:
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
new file mode 100644
index 0000000..afc22c8
--- /dev/null
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -0,0 +1,47 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from vllm.config import ParallelConfig
+
+import vllm_ascend.distributed.parallel_state as parallel_state
+from vllm_ascend.distributed.parallel_state import (
+    destroy_ascend_model_parallel, get_lmhead_tp_group, get_mc2_group,
+    init_ascend_model_parallel)
+
+
+@pytest.fixture
+def parallel_config():
+    return ParallelConfig(data_parallel_size=2,
+                          tensor_parallel_size=2,
+                          pipeline_parallel_size=2)
+
+
+@pytest.fixture
+def mock_distributed():
+    with patch('torch.distributed.is_initialized', return_value=True), \
+         patch('torch.distributed.get_world_size', return_value=8), \
+         patch('torch.distributed.get_backend', return_value='nccl'), \
+         patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
+        mock_group.return_value.local_rank = 0
+        mock_group.return_value.device_group = MagicMock()
+        yield
+
+
+def test_init_ascend_model_parallel(mock_distributed, parallel_config):
+    mock_ascend_config = MagicMock()
+    mock_ascend_config.lmhead_tensor_parallel_size = 2
+    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
+         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
+         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
+        init_ascend_model_parallel(parallel_config)
+
+        mc2_group = get_mc2_group()
+        assert mc2_group is not None
+        lmheadtp_group = get_lmhead_tp_group()
+        assert lmheadtp_group is not None
+
+        destroy_ascend_model_parallel()
+        # Check the module-level globals; names bound via "from ... import"
+        # would not reflect the reset done by destroy_ascend_model_parallel().
+        assert parallel_state._MC2 is None
+        assert parallel_state._LMTP is None
diff --git a/tests/ut/models/test_deepseek_mtp.py b/tests/ut/models/test_deepseek_mtp.py
index 45b6ed5..61fdf98 100644
--- a/tests/ut/models/test_deepseek_mtp.py
+++ b/tests/ut/models/test_deepseek_mtp.py
@@ -31,6 +31,11 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
         mocker_deepseek_v2_decode_layer = mocker.patch(
             "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())

         mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
         mocker_deepseek_v2_decode_layer.assert_called_once()
@@ -83,6 +88,11 @@ class
TestCustomDeepSeekMultiTokenPredictor(PytestBase):
         mocker.patch(
             "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
             return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())

         predictor = CustomDeepSeekMultiTokenPredictor(
             vllm_config=mock_vllm_config)
@@ -157,6 +167,11 @@ class TestCustomDeepSeekMTP(PytestBase):
                      return_value=None)
         mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                      return_value=None)
+        mocker.patch(
+            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
+            return_value=None)
+        mocker.patch("vllm_ascend.utils.get_ascend_config",
+                     return_value=mocker.Mock())

         mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
         return mtp
@@ -177,4 +192,4 @@ class TestCustomDeepSeekMTP(PytestBase):
         output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                    previous_hidden_states, inputs_embeds,
                                    spec_step_idx)
-        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
\ No newline at end of file
+        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index e0c50e8..df14a2a 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -26,7 +26,7 @@ from vllm_ascend.models.deepseek_v2 import (
     CustomDeepseekV2MLP, CustomDeepseekV2MoE,
     CustomDeepseekV2RowParallelLinear,
     CustomDeepseekV2RowParallelLinearReplaceAllreduce,
-    CustomDeepseekV2SiluAndMul)
+    CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)


 @pytest.fixture
@@ -266,3 +266,30 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                  kv_lora_rank=16,
                                  prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")
+
+
+def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
+    # Create a minimal config object for the test
+    class SimpleConfig:
+
+        def __init__(self):
+            self.vocab_size = 10000
+            self.hidden_size = 128
+
+    config = SimpleConfig()
+
+    # Create the lm_head and logits processor directly
+    lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
+    logits_processor = LogitsProcessor(config.vocab_size)
+
+    # Create mock hidden states and logits
+    mock_output = torch.randn(2, 4, config.hidden_size)
+    mock_logits = torch.randn(2, 4, config.vocab_size)
+
+    # Exercise the logits processor directly
+    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
+        with patch.object(logits_processor,
+                          "_gather_logits",
+                          return_value=mock_logits):
+            logits = logits_processor(lmhead, mock_output)
+            assert logits.shape == (2, 4, config.vocab_size)
diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
index 13ede67..5378b19 100644
--- a/tests/ut/ops/test_vocab_parallel_embedding.py
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -18,8 +18,8 @@ from unittest.mock import MagicMock, patch

 import torch

-from vllm_ascend.ops.vocab_parallel_embedding import \
-    AscendVocabParallelEmbedding
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)

 VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128

@@ -34,7 +34,11 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):

     def _create_layer(self):
         # Patch methods and dependencies for VocabParallelEmbedding
-        with patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
+        mock_group = 
MagicMock() + mock_group.world_size = 2 + mock_group.rank_in_group = 0 + with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \ + patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \ patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \ patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \ patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y): @@ -174,3 +178,55 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase): # Call the forward method output = layer.forward(input_) self.assertEqual(output.shape, expected_shape) + + +class TestAscendLogitsProcessor(unittest.TestCase): + + def setUp(self): + self.vocab_size = 50 + self.num_embeddings = 50 + self.embedding_dim = 10 + self.org_num_embeddings = 40 + self.padding_size = 8 + + self.mock_group = MagicMock() + self.mock_group.world_size = 2 + self.mock_group.rank_in_group = 0 + self.mock_ascend_config = MagicMock() + self.mock_quant_method = MagicMock() + self.mock_quant_method.apply = MagicMock( + return_value=torch.randn(1, self.vocab_size)) + self.patches = [ + patch("vllm_ascend.ascend_config.get_ascend_config", + return_value=self.mock_ascend_config), + patch( + "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group", + return_value=self.mock_group), + patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable", + return_value=True), + patch( + "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all", + return_value=torch.randn(1, self.vocab_size)) + ] + + for p in self.patches: + p.start() + + def tearDown(self): + for p in self.patches: + p.stop() + + def test_create_processor(self): + processor = AscendLogitsProcessor(vocab_size=self.vocab_size) + self.assertEqual(processor.vocab_size, self.vocab_size) + + def test_get_logits(self): + processor = AscendLogitsProcessor(vocab_size=self.vocab_size) + lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings, + embedding_dim=self.embedding_dim, + prefix="lm_head") + lmhead.quant_method = self.mock_quant_method + lmhead.quant_method.apply = self.mock_quant_method.apply + hidden_state = torch.randn(1, self.org_num_embeddings) + processor._get_logits(hidden_state, lmhead) + self.mock_quant_method.apply.assert_called_once() diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 49b9abe..622b751 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -16,7 +16,7 @@ import os from transformers import PretrainedConfig -from vllm.config import ModelConfig, VllmConfig +from vllm.config import ModelConfig, ParallelConfig, VllmConfig from tests.ut.base import TestBase from vllm_ascend.ascend_config import (_check_torchair_supported, @@ -75,7 +75,7 @@ class TestAscendConfig(TestBase): "enabled": True }, "expert_map_path": "test_expert_map_path", - "refresh": True + "refresh": True, } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") @@ -304,3 +304,12 @@ class TestAscendConfig(TestBase): "refresh": True } init_ascend_config(test_vllm_config) + + with self.assertRaises(AssertionError): + test_vllm_config.additional_config = { + "lmhead_tensor_parallel_size": 2, + "refresh": True + } + test_vllm_config.parallel_config = ParallelConfig( + data_parallel_size=4, 
tensor_parallel_size=2) + init_ascend_config(test_vllm_config) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 396f457..0d264c7 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -289,13 +289,13 @@ class TestUtils(TestBase): # ascend custom op is not registered utils.register_ascend_customop() # should call register_oot three - self.assertEqual(mock_customop.register_oot.call_count, 10) + self.assertEqual(mock_customop.register_oot.call_count, 12) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) # ascend custom op is already registered utils.register_ascend_customop() # should not register_oot again, thus only called three in this ut - self.assertEqual(mock_customop.register_oot.call_count, 10) + self.assertEqual(mock_customop.register_oot.call_count, 12) class TestProfileExecuteDuration(TestBase): diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py index 1c1e6c7..7aafdfc 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py @@ -31,6 +31,11 @@ class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase): mocker_deepseek_v2_decode_layer = mocker.patch( "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__", return_value=None) + mocker.patch( + "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", + return_value=None) + mocker.patch("vllm_ascend.utils.get_ascend_config", + return_value=mocker.Mock()) mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None) mocker_deepseek_v2_decode_layer.assert_called_once() @@ -83,6 +88,11 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase): mocker.patch( "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__", return_value=None) + mocker.patch( + "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", + return_value=None) + mocker.patch("vllm_ascend.utils.get_ascend_config", + return_value=mocker.Mock()) predictor = TorchairDeepSeekMultiTokenPredictor( vllm_config=mock_vllm_config) @@ -157,6 +167,11 @@ class TestTorchairDeepSeekMTP(PytestBase): return_value=None) mocker.patch("vllm.model_executor.layers.sampler.get_sampler", return_value=None) + mocker.patch( + "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", + return_value=None) + mocker.patch("vllm_ascend.utils.get_ascend_config", + return_value=mocker.Mock()) mtp = TorchairDeepSeekMTP(vllm_config=vllm_config) return mtp @@ -177,4 +192,4 @@ class TestTorchairDeepSeekMTP(PytestBase): output = setup_mtp.forward(input_ids, positions, kv_caches, None, previous_hidden_states, inputs_embeds, spec_step_idx) - assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) \ No newline at end of file + assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 81cf177..2a2ac7b 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -51,6 +51,16 @@ class AscendConfig: "enable_shared_expert_dp", False ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel self.enable_prefetch = additional_config.get("enable_prefetch", False) + self.lmhead_tensor_parallel_size = additional_config.get( + "lmhead_tensor_parallel_size", None) + if self.lmhead_tensor_parallel_size is not None: + logger.info( + 
f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario" + ) + if vllm_config.parallel_config.tensor_parallel_size != 1: + raise AssertionError( + "lmhead_tensor_parallel_size is only supported in the pure DP scenario" + ) class TorchairGraphConfig: diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index db1c5a8..f81d501 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -6,17 +6,26 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, init_model_parallel_group) import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config # Currently, mc2 op need their own group coordinator. _MC2: Optional[GroupCoordinator] = None _MLP_TP: Optional[GroupCoordinator] = None +_LMTP: Optional[GroupCoordinator] = None + def get_mc2_group() -> GroupCoordinator: assert _MC2 is not None, ("mc2 group is not initialized") return _MC2 +def get_lmhead_tp_group() -> GroupCoordinator: + assert _LMTP is not None, ( + "lm head tensor parallel group is not initialized") + return _LMTP + + def get_mlp_tp_group() -> GroupCoordinator: assert _MLP_TP is not None, ("mlp group is not initialized") return _MLP_TP @@ -65,6 +74,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): backend, group_name="mlp_tp") + lmhead_tensor_parallel_size = get_ascend_config( + ).lmhead_tensor_parallel_size + if lmhead_tensor_parallel_size is not None: + group_ranks = [] + global _LMTP + num_lmhead_tensor_parallel_groups: int = (world_size // + lmhead_tensor_parallel_size) + for i in range(num_lmhead_tensor_parallel_groups): + ranks = list( + range(i * lmhead_tensor_parallel_size, + (i + 1) * lmhead_tensor_parallel_size)) + group_ranks.append(ranks) + _LMTP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="lmheadtp") + def get_mlp_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" @@ -86,3 +112,8 @@ def destroy_ascend_model_parallel(): if _MLP_TP: _MLP_TP.destroy() _MLP_TP = None + + global _LMTP + if _LMTP: + _LMTP.destroy() + _LMTP = None diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 05b08a4..7ad35dc 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -15,15 +15,108 @@ # limitations under the License. 
#

-from typing import Tuple
+from typing import Optional, Tuple

 import torch
-from vllm.distributed import tensor_model_parallel_all_reduce
-from vllm.model_executor.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from torch import nn
+from torch.nn.parameter import Parameter
+from vllm.distributed import divide, tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.utils import set_weight_attrs
+
+from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
+from vllm_ascend.utils import lmhead_tp_enable


 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
+    """
+    Register VocabParallelEmbedding as a custom op for Ascend.
+    AscendVocabParallelEmbedding supports different communication parallel groups
+    and adds LM head tensor parallelism (lmhead TP) in the pure DP scenario.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        nn.Module.__init__(self)
+
+        if lmhead_tp_enable() and prefix.find("lm_head") != -1:
+            self.comm_group = get_lmhead_tp_group()
+        else:
+            self.comm_group = get_tp_group()
+
+        self.tp_size = self.comm_group.world_size
+        self.tp_rank = self.comm_group.rank_in_group
+
+        self.num_embeddings = num_embeddings
+        self.padding_size = padding_size
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
+                                                    self.padding_size)
+        self.num_embeddings_padded = pad_vocab_size(
+            self.org_vocab_size_padded + num_added_embeddings,
+            self.padding_size)
+        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+
+        self.shard_indices = self._get_indices(self.num_embeddings_padded,
+                                               self.org_vocab_size_padded,
+                                               self.num_embeddings,
+                                               self.org_vocab_size,
+                                               self.tp_rank, self.tp_size)
+        self.embedding_dim = embedding_dim
+        quant_method = None
+        if quant_config is not None:
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()
+
+        # If we are making an embedding layer, then our quantization linear
+        # method must implement the embedding operation. If we are another
+        # layer type like ParallelLMHead, this is not important.
+        is_embedding_layer = type(self) is VocabParallelEmbedding
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
+            raise NotImplementedError(
+                f"The class {type(quant_method).__name__} must implement "
+                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+
+        self.quant_method: QuantizeMethodBase = quant_method
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        # Divide the weight matrix along the vocabulary dimension.
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) def _get_masked_input_and_mask( self, input_: torch.Tensor, org_vocab_start_index: int, @@ -71,3 +164,91 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): # Reduce across all the model parallel GPUs. output = tensor_model_parallel_all_reduce(output_parallel) return output + + +class AscendParallelLMHead(ParallelLMHead): + """ + Register ParallelLMHead as a custom op for Ascend.""" + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + AscendVocabParallelEmbedding.__init__(self, num_embeddings, + embedding_dim, params_dtype, + org_num_embeddings, padding_size, + quant_config, prefix) + + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + +class AscendLogitsProcessor(LogitsProcessor): + """ + Register LogitsProcessor as a custom op for Ascend. 
+ Added the feature of lmheadTP in pure dp scenario + """ + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + if lmhead_tp_enable(): + return self._get_logits_lmheadtp(hidden_states, lm_head, + embedding_bias) + else: + return self._get_logits_normal(hidden_states, lm_head, + embedding_bias) + + def _get_logits_lmheadtp( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # Gather hidden states from all devices in tensor parallel group + gathered_hidden_states = get_lmhead_tp_group().all_gather( + hidden_states, dim=0) + local_logits = lm_head.quant_method.apply(lm_head, + gathered_hidden_states, + bias=embedding_bias) + # Gather logits for tensor parallel + logits = get_lmhead_tp_group().all_to_all(local_logits) + # Remove paddings in vocab (if any) + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits + + def _get_logits_normal( + self, + hidden_states: torch.Tensor, + lm_head: AscendParallelLMHead, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + local_logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + # Gather logits for tensor parallel + logits = self._gather_logits(local_logits) + + # Remove paddings in vocab (if any) + if logits is not None: + logits = logits[..., :self.org_vocab_size] + + return logits diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index a99a491..adab490 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -33,6 +33,7 @@ from torch_npu.npu.streams import Event from vllm.logger import logger import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config if TYPE_CHECKING: from vllm.config import VllmConfig @@ -489,6 +490,9 @@ def register_ascend_customop(): AscendMlpRowParallelLinear) from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + from vllm_ascend.ops.vocab_parallel_embedding import ( + AscendLogitsProcessor, AscendParallelLMHead, + AscendVocabParallelEmbedding) CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU") CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul, name="SiluAndMul") @@ -497,6 +501,12 @@ def register_ascend_customop(): CustomOp.register_oot( _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding, name="DeepseekScalingRotaryEmbedding") + CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding, + name="VocabParallelEmbedding") + CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead, + name="ParallelLMHead") + CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor, + name="LogitsProcessor") if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE: CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear, name="ColumnParallelLinear") @@ -512,11 +522,6 @@ def register_ascend_customop(): from vllm_ascend.ops.common_fused_moe import AscendFusedMoE CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE") - from vllm_ascend.ops.vocab_parallel_embedding import \ - AscendVocabParallelEmbedding - CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding, - name="VocabParallelEmbedding") - # NOTE: Keep this at last to ensure all custom actions are registered _ASCEND_CUSTOMOP_IS_REIGISTERED = True @@ -547,3 +552,7 @@ def get_ascend_soc_version(): global 
_ascend_soc_version assert _ascend_soc_version is not None return _ascend_soc_version + + +def lmhead_tp_enable() -> bool: + return get_ascend_config().lmhead_tensor_parallel_size is not None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6930a1c..15effb7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -90,7 +90,7 @@ from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, ProfileExecuteDuration, is_310p, - vllm_version_is) + lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch @@ -1277,6 +1277,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + if lmhead_tp_enable(): + max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs + logits_indices = nn.functional.pad( + logits_indices, + (0, max_num_reqs_across_dp - logits_indices.shape[0])) + return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, @@ -1734,11 +1740,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[:self.input_batch.num_reqs] sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) else: + if lmhead_tp_enable() and logits is not None: + logits = logits[:len(spec_decode_metadata.logits_indices)] # When indexing with a tensor (bonus_logits_indices), PyTorch # creates a new tensor with separate storage from the original # logits tensor. This means any in-place operations on bonus_logits @@ -2081,6 +2091,18 @@ class NPUModelRunner(LoRAModelRunnerMixin): f"Aclgraph runtime mode mismatch at dummy_run. 
" f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.") + need_dummy_logits = (not self.in_profile_run + and lmhead_tp_enable()) + + if need_dummy_logits: + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, + dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + return self.model.compute_logits( + hidden_states[dummy_indices], None) + with set_ascend_forward_context( attn_metadata, self.vllm_config, @@ -2097,6 +2119,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) + if need_dummy_logits: + dummy_compute_logits(hidden_states) + if self.speculative_config and self.speculative_config.method == "deepseek_mtp": assert isinstance(self.drafter, MtpProposer) self.drafter.dummy_run( @@ -2105,7 +2130,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): skip_attn=True, num_reqs=num_reqs, num_tokens_across_dp=num_tokens_across_dp) - + if need_dummy_logits: + dummy_compute_logits(hidden_states) return hidden_states @contextmanager diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 120b17a..848da93 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -19,7 +19,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata -from vllm_ascend.utils import ProfileExecuteDuration +from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable class MtpProposer: @@ -235,8 +235,20 @@ class MtpProposer: previous_hidden_states=self. hidden_states[:num_input_tokens], kv_caches=self.runner.kv_caches[-1:]) + + num_indices = last_token_indices.shape[0] + if lmhead_tp_enable(): + if not self.runner.with_prefill: + max_num_reqs_across_dp = num_input_tokens + else: + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs + last_token_indices = nn.functional.pad( + last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + if lmhead_tp_enable() and num_indices < logits.shape[0]: + logits = logits[:num_indices] draft_token_ids = logits.argmax(dim=-1) # [batch_size, 1]