[Feat]: Add custom lmhead tensor model parallel (#2309)

### What this PR does / why we need it?
This PR introduces LM-head tensor model parallelism to reduce memory
consumption and improve TPOT performance. It supports both eager mode
and graph mode.

In a deepseek r1 w8a8 PD disaggregated Decode instance using pure DP, with
lmhead_tensor_parallel_size = 8, we observe a 1 ms TPOT improvement and save
1.48 GB of NPU memory per rank.

performance data:
<img width="1444" height="438" alt="image"
src="https://github.com/user-attachments/assets/3c5ef0d3-a7c7-46fd-9797-4de728eb0cb0"
/>

### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the
column dimension (vocab_size) into lmhead_tensor_parallel_size pieces |
No | int | default value is None, once this value is set, the feature
will be enabled, vocab_size must be divisible by this value. |

example

`--additional_config={"lmhead_tensor_parallel_size": 8}`

### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
de533ab2a1

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>
This commit is contained in:
lidenghui1110
2025-08-29 11:41:21 +08:00
committed by GitHub
parent e7ad4a64f4
commit 600b08f754
14 changed files with 458 additions and 22 deletions

View File

@@ -51,6 +51,16 @@ class AscendConfig:
"enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.enable_prefetch = additional_config.get("enable_prefetch", False)
self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None)
if self.lmhead_tensor_parallel_size is not None:
logger.info(
f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
)
if vllm_config.parallel_config.tensor_parallel_size != 1:
raise AssertionError(
"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
)
class TorchairGraphConfig:

View File

@@ -6,17 +6,26 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
# Currently, mc2 ops need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
# Dedicated tensor-parallel group used only by MLP layers.
_MLP_TP: Optional[GroupCoordinator] = None
# Tensor-parallel group for the LM head; created only when
# lmhead_tensor_parallel_size is set in additional_config.
_LMTP: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
    """Return the MC2 group coordinator; fails if it was never initialized."""
    group = _MC2
    assert group is not None, ("mc2 group is not initialized")
    return group
def get_lmhead_tp_group() -> GroupCoordinator:
    """Return the LM-head tensor-parallel group coordinator.

    Only valid after init_ascend_model_parallel ran with
    lmhead_tensor_parallel_size configured.
    """
    group = _LMTP
    assert group is not None, (
        "lm head tensor parallel group is not initialized")
    return group
def get_mlp_tp_group() -> GroupCoordinator:
    """Return the MLP tensor-parallel group coordinator; fails if unset."""
    group = _MLP_TP
    assert group is not None, ("mlp group is not initialized")
    return group
@@ -65,6 +74,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mlp_tp")
lmhead_tensor_parallel_size = get_ascend_config(
).lmhead_tensor_parallel_size
if lmhead_tensor_parallel_size is not None:
group_ranks = []
global _LMTP
num_lmhead_tensor_parallel_groups: int = (world_size //
lmhead_tensor_parallel_size)
for i in range(num_lmhead_tensor_parallel_groups):
ranks = list(
range(i * lmhead_tensor_parallel_size,
(i + 1) * lmhead_tensor_parallel_size))
group_ranks.append(ranks)
_LMTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="lmheadtp")
def get_mlp_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
@@ -86,3 +112,8 @@ def destroy_ascend_model_parallel():
if _MLP_TP:
_MLP_TP.destroy()
_MLP_TP = None
global _LMTP
if _LMTP:
_LMTP.destroy()
_LMTP = None

View File

@@ -15,15 +15,108 @@
# limitations under the License.
#
from typing import Tuple
from typing import Optional, Tuple
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import divide, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
VocabParallelEmbedding, pad_vocab_size)
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
from vllm_ascend.utils import lmhead_tp_enable
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
"""
Register VocabParallelEmbedding as a custom op for Ascend.
AscendVocabParallelEmbedding support different communication parallel groups
Added the feature of lmheadTP in pure dp scenario
"""
def __init__(self,
             num_embeddings: int,
             embedding_dim: int,
             params_dtype: Optional[torch.dtype] = None,
             org_num_embeddings: Optional[int] = None,
             padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = ""):
    """Build a vocab-parallel embedding whose shard layout follows either
    the default TP group or, for the LM head with lmheadTP enabled, the
    dedicated LM-head tensor-parallel group.

    Mirrors vLLM's VocabParallelEmbedding.__init__ except for the
    comm_group selection below.

    Args:
        num_embeddings: total vocabulary size (before padding).
        embedding_dim: hidden size of each embedding vector.
        params_dtype: dtype for created weights; defaults to torch default.
        org_num_embeddings: original vocab size, when num_embeddings was
            extended (e.g. for LoRA-added tokens).
        padding_size: pad the vocab to a multiple of this value.
        quant_config: optional quantization config providing the weight
            creation/apply method.
        prefix: parameter-name prefix; used both for quant-method lookup
            and to detect the "lm_head" layer.
    """
    # Deliberately skip VocabParallelEmbedding.__init__: it would bind to
    # the default TP group before we can substitute the LM-head group.
    nn.Module.__init__(self)

    # Route the LM head through its own TP group when the feature is on;
    # every other embedding keeps the regular tensor-parallel group.
    if lmhead_tp_enable() and prefix.find("lm_head") != -1:
        self.comm_group = get_lmhead_tp_group()
    else:
        self.comm_group = get_tp_group()
    self.tp_size = self.comm_group.world_size
    self.tp_rank = self.comm_group.rank_in_group

    # Keep the padded/unpadded bookkeeping for both the original vocab and
    # any added tokens; shard indices are computed against the padded size.
    self.num_embeddings = num_embeddings
    self.padding_size = padding_size
    self.org_vocab_size = org_num_embeddings or num_embeddings
    num_added_embeddings = num_embeddings - self.org_vocab_size
    self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                self.padding_size)
    self.num_embeddings_padded = pad_vocab_size(
        self.org_vocab_size_padded + num_added_embeddings,
        self.padding_size)
    assert self.org_vocab_size_padded <= self.num_embeddings_padded

    self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                           self.org_vocab_size_padded,
                                           self.num_embeddings,
                                           self.org_vocab_size,
                                           self.tp_rank, self.tp_size)
    self.embedding_dim = embedding_dim

    # Resolve the quantized (or unquantized fallback) weight method.
    quant_method = None
    if quant_config is not None:
        quant_method = quant_config.get_quant_method(self, prefix=prefix)
    if quant_method is None:
        quant_method = UnquantizedEmbeddingMethod()

    # If we are making an embedding layer, then our quantization linear
    # method must implement the embedding operation. If we are another
    # layer type like ParallelLMHead, this is not important.
    # NOTE(review): type(self) is never exactly VocabParallelEmbedding for
    # this subclass, so this guard is effectively skipped here — confirm
    # that is the intended behavior for the Ascend custom op.
    is_embedding_layer = type(self) is VocabParallelEmbedding
    quant_method_implements_embedding = method_has_implemented_embedding(
        type(quant_method))
    if is_embedding_layer and not quant_method_implements_embedding:
        raise NotImplementedError(
            f"The class {type(quant_method).__name__} must implement "
            "the 'embedding' method, see UnquantizedEmbeddingMethod.")

    self.quant_method: QuantizeMethodBase = quant_method

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()

    # Divide the weight matrix along the vocabulary dimension.
    self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
    self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                               self.tp_size)
    assert (self.shard_indices.num_elements_padded ==
            self.num_embeddings_per_partition)
    self.num_org_embeddings_per_partition = (
        self.shard_indices.org_vocab_end_index -
        self.shard_indices.org_vocab_start_index)
    self.num_added_embeddings_per_partition = (
        self.shard_indices.added_vocab_end_index -
        self.shard_indices.added_vocab_start_index)

    # Allocate this rank's weight shard via the resolved quant method.
    self.quant_method.create_weights(self,
                                     self.embedding_dim,
                                     [self.num_embeddings_per_partition],
                                     self.embedding_dim,
                                     self.num_embeddings_padded,
                                     params_dtype=params_dtype,
                                     weight_loader=self.weight_loader)
def _get_masked_input_and_mask(
self, input_: torch.Tensor, org_vocab_start_index: int,
@@ -71,3 +164,91 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
class AscendParallelLMHead(ParallelLMHead):
    """Ascend override of ParallelLMHead.

    Delegates initialization to AscendVocabParallelEmbedding so that the
    weight shard is created against the LM-head tensor-parallel group
    whenever that feature is enabled.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Intentionally bypass ParallelLMHead.__init__: the Ascend
        # embedding init performs the comm-group selection.
        AscendVocabParallelEmbedding.__init__(self, num_embeddings,
                                              embedding_dim, params_dtype,
                                              org_num_embeddings, padding_size,
                                              quant_config, prefix)
        self.quant_config = quant_config

        # Guard clause: the common (no-bias) case registers a placeholder
        # and returns early.
        if not bias:
            self.register_parameter("bias", None)
            return

        # One bias value per vocab row in this rank's partition.
        self.bias = Parameter(
            torch.empty(self.num_embeddings_per_partition,
                        dtype=params_dtype))
        set_weight_attrs(self.bias, {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        })
class AscendLogitsProcessor(LogitsProcessor):
    """
    Register LogitsProcessor as a custom op for Ascend.
    Added the feature of lmheadTP in pure dp scenario
    """

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Dispatch to the LM-head-TP path when the feature is enabled."""
        compute = (self._get_logits_lmheadtp
                   if lmhead_tp_enable() else self._get_logits_normal)
        return compute(hidden_states, lm_head, embedding_bias)

    def _get_logits_lmheadtp(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Compute logits with the weight sharded over the LM-head TP group."""
        lmhead_group = get_lmhead_tp_group()
        # Every rank needs all hidden states before the sharded matmul.
        full_hidden = lmhead_group.all_gather(hidden_states, dim=0)
        shard_logits = lm_head.quant_method.apply(lm_head,
                                                  full_hidden,
                                                  bias=embedding_bias)
        # Redistribute so each rank gets its own rows with the full vocab.
        logits = lmhead_group.all_to_all(shard_logits)
        return self._strip_vocab_padding(logits)

    def _get_logits_normal(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Default path: local matmul, then the regular TP logits gather."""
        shard_logits = lm_head.quant_method.apply(lm_head,
                                                  hidden_states,
                                                  bias=embedding_bias)
        logits = self._gather_logits(shard_logits)
        return self._strip_vocab_padding(logits)

    def _strip_vocab_padding(
            self,
            logits: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Remove padded vocab columns; passes None through unchanged."""
        if logits is None:
            return None
        return logits[..., :self.org_vocab_size]

View File

@@ -33,6 +33,7 @@ from torch_npu.npu.streams import Event
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -489,6 +490,9 @@ def register_ascend_customop():
AscendMlpRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead,
AscendVocabParallelEmbedding)
CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
name="SiluAndMul")
@@ -497,6 +501,12 @@ def register_ascend_customop():
CustomOp.register_oot(
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
name="DeepseekScalingRotaryEmbedding")
CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
name="VocabParallelEmbedding")
CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead,
name="ParallelLMHead")
CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor,
name="LogitsProcessor")
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
name="ColumnParallelLinear")
@@ -512,11 +522,6 @@ def register_ascend_customop():
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
name="VocabParallelEmbedding")
# NOTE: Keep this at last to ensure all custom actions are registered
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
@@ -547,3 +552,7 @@ def get_ascend_soc_version():
global _ascend_soc_version
assert _ascend_soc_version is not None
return _ascend_soc_version
def lmhead_tp_enable() -> bool:
    """True when lmhead_tensor_parallel_size was set in additional_config."""
    lmhead_tp_size = get_ascend_config().lmhead_tensor_parallel_size
    return lmhead_tp_size is not None

View File

@@ -90,7 +90,7 @@ from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
vllm_version_is)
lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -1277,6 +1277,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
if lmhead_tp_enable():
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
logits_indices = nn.functional.pad(
logits_indices,
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp,
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
@@ -1734,11 +1740,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
if lmhead_tp_enable() and logits is not None:
logits = logits[:self.input_batch.num_reqs]
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
if lmhead_tp_enable() and logits is not None:
logits = logits[:len(spec_decode_metadata.logits_indices)]
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
@@ -2081,6 +2091,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
f"Aclgraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
need_dummy_logits = (not self.in_profile_run
and lmhead_tp_enable())
if need_dummy_logits:
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)
def dummy_compute_logits(hidden_states):
return self.model.compute_logits(
hidden_states[dummy_indices], None)
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
@@ -2097,6 +2119,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,
inputs_embeds)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
assert isinstance(self.drafter, MtpProposer)
self.drafter.dummy_run(
@@ -2105,7 +2130,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
skip_attn=True,
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
return hidden_states
@contextmanager

View File

@@ -19,7 +19,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import ProfileExecuteDuration
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
class MtpProposer:
@@ -235,8 +235,20 @@ class MtpProposer:
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
if not self.runner.with_prefill:
max_num_reqs_across_dp = num_input_tokens
else:
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
last_token_indices = nn.functional.pad(
last_token_indices, (0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
# [batch_size, 1]