[Feat]: Add custom lmhead tensor model parallel (#2309)

### What this PR does / why we need it?
This PR introduces LMhead tensor model parallel to achieve decreasing of
memory consumption, and TPOT performance improvement. It support both
eager mode and graph mode.

In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
lmhead_tensor_parallel_size = 8, we have 1 ms TPOT optimization, saved
1.48 GB NPU memory per RANK.

performance data:
<img width="1444" height="438" alt="image"
src="https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0"
/>

### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the
column dimension (vocab_size) into lmhead_tensor_parallel_size pieces |
No | int | default value is None, once this value is set, the feature
will be enabled, vocab_size must be divisible by this value. |

example

`--additional_config={"lmhead_tensor_parallel_size": 8}`

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
de533ab2a1

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
This commit is contained in:
lidenghui1110
2025-08-29 11:41:21 +08:00
committed by GitHub
parent e7ad4a64f4
commit 600b08f754
14 changed files with 458 additions and 22 deletions

View File

@@ -15,15 +15,108 @@
# limitations under the License.
#
from typing import Tuple
from typing import Optional, Tuple
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import divide, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
VocabParallelEmbedding, pad_vocab_size)
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
from vllm_ascend.utils import lmhead_tp_enable
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
"""
Register VocabParallelEmbedding as a custom op for Ascend.
AscendVocabParallelEmbedding support different communication parallel groups
Added the feature of lmheadTP in pure dp scenario
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
nn.Module.__init__(self)
if lmhead_tp_enable() and prefix.find("lm_head") != -1:
self.comm_group = get_lmhead_tp_group()
else:
self.comm_group = get_tp_group()
self.tp_size = self.comm_group.world_size
self.tp_rank = self.comm_group.rank_in_group
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size,
self.tp_rank, self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def _get_masked_input_and_mask(
self, input_: torch.Tensor, org_vocab_start_index: int,
@@ -71,3 +164,91 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
class AscendParallelLMHead(ParallelLMHead):
"""
Register ParallelLMHead as a custom op for Ascend."""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
AscendVocabParallelEmbedding.__init__(self, num_embeddings,
embedding_dim, params_dtype,
org_num_embeddings, padding_size,
quant_config, prefix)
self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
class AscendLogitsProcessor(LogitsProcessor):
"""
Register LogitsProcessor as a custom op for Ascend.
Added the feature of lmheadTP in pure dp scenario
"""
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: AscendParallelLMHead,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if lmhead_tp_enable():
return self._get_logits_lmheadtp(hidden_states, lm_head,
embedding_bias)
else:
return self._get_logits_normal(hidden_states, lm_head,
embedding_bias)
def _get_logits_lmheadtp(
self,
hidden_states: torch.Tensor,
lm_head: AscendParallelLMHead,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Gather hidden states from all devices in tensor parallel group
gathered_hidden_states = get_lmhead_tp_group().all_gather(
hidden_states, dim=0)
local_logits = lm_head.quant_method.apply(lm_head,
gathered_hidden_states,
bias=embedding_bias)
# Gather logits for tensor parallel
logits = get_lmhead_tp_group().all_to_all(local_logits)
# Remove paddings in vocab (if any)
if logits is not None:
logits = logits[..., :self.org_vocab_size]
return logits
def _get_logits_normal(
self,
hidden_states: torch.Tensor,
lm_head: AscendParallelLMHead,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
local_logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
# Gather logits for tensor parallel
logits = self._gather_logits(local_logits)
# Remove paddings in vocab (if any)
if logits is not None:
logits = logits[..., :self.org_vocab_size]
return logits