[Feat] Add custom Embedding tensor model parallel (#2616)

Similar to #2309, this PR introduces Embedding tensor model parallel to reduce memory consumption. It supports both eager mode and graph mode.
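
For context, here is a minimal sketch of the embedding tensor-parallel forward pattern (not the PR's code; it assumes plain `torch.distributed` and a hypothetical `embed_tp_forward` helper): each rank keeps only its shard of the vocabulary rows, all-gathers the token ids, looks up the ids that fall inside its shard, and reduce-scatters the partial results so every rank ends up with the embeddings for its own slice of the batch.

```python
import torch
import torch.distributed as dist


def embed_tp_forward(token_ids: torch.Tensor,
                     weight_shard: torch.Tensor,
                     vocab_start: int,
                     vocab_end: int,
                     group: dist.ProcessGroup) -> torch.Tensor:
    """Sketch only: weight_shard holds embedding rows [vocab_start, vocab_end)."""
    world_size = dist.get_world_size(group)
    # 1. All-gather token ids so every rank sees the full batch.
    gathered = [torch.empty_like(token_ids) for _ in range(world_size)]
    dist.all_gather(gathered, token_ids, group=group)
    complete_ids = torch.cat(gathered, dim=0)
    # 2. Mask ids outside this rank's vocab shard and look them up locally.
    out_of_shard = (complete_ids < vocab_start) | (complete_ids >= vocab_end)
    local_ids = (complete_ids - vocab_start).masked_fill(out_of_shard, 0)
    partial = torch.nn.functional.embedding(local_ids.long(), weight_shard)
    partial.masked_fill_(out_of_shard.unsqueeze(-1), 0)
    # 3. Sum the partial embeddings across ranks and keep this rank's slice.
    output = torch.empty_like(partial.chunk(world_size, dim=0)[0])
    dist.reduce_scatter_tensor(output, partial, group=group)
    return output
```

The memory saving comes from each rank holding only roughly `num_embeddings / embedding_tensor_parallel_size` rows of the embedding weight.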

This PR also refactors the module tensor parallel configurations introduced in #2309, #2167, and #2120, merging them all into `finegrained_tp_config` under `additional_config`. The supported keys are listed below (a hedged usage sketch follows the list):
- `lmhead_tensor_parallel_size`
- `oproj_tensor_parallel_size`
- `embedding_tensor_parallel_size`
- `mlp_tensor_parallel_size`
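
As a usage sketch (the model name and size values are illustrative, and the exact way `additional_config` is passed may differ in your deployment):

```python
from vllm import LLM

# Illustrative values only; consult the vllm-ascend docs for valid size constraints.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model choice
    tensor_parallel_size=8,
    additional_config={
        "finegrained_tp_config": {
            "lmhead_tensor_parallel_size": 8,
            "oproj_tensor_parallel_size": 4,
            "embedding_tensor_parallel_size": 8,
            "mlp_tensor_parallel_size": 4,
        },
    },
)
```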

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
Author: lidenghui1110
Date: 2025-12-12 14:41:20 +08:00
Committed by: GitHub
Parent: b8a317caac
Commit: d65fb194d9
9 changed files with 301 additions and 162 deletions


@@ -30,8 +30,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
-from vllm_ascend.utils import lmhead_tp_enable
+from vllm_ascend.distributed.parallel_state import (get_embed_tp_group,
+                                                    get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable


 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
@@ -50,9 +51,12 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         nn.Module.__init__(self)
-        if lmhead_tp_enable() and prefix.find("head") != -1:
+        self.forward_type = None
+        if lmhead_tp_enable() and "head" in prefix:
             self.comm_group = get_lmhead_tp_group()
+        elif embedding_tp_enable() and "embed_tokens" in prefix:
+            self.comm_group = get_embed_tp_group()
+            self.forward_type = "embed_tp"
         else:
             self.comm_group = get_tp_group()
@@ -146,6 +150,28 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
         return input_, ~vocab_mask

     def forward(self, input_):
+        if self.forward_type == "embed_tp":
+            return self._forward_embed_tp(input_)
+        else:
+            return self._forward_origin(input_)
+
+    def _forward_embed_tp(self, input_):
+        complete_input = self.comm_group.all_gather(input_, dim=0)
+        masked_input, input_mask = self._get_masked_input_and_mask(
+            complete_input, self.shard_indices.org_vocab_start_index,
+            self.shard_indices.org_vocab_end_index,
+            self.shard_indices.num_org_vocab_padding,
+            self.shard_indices.added_vocab_start_index,
+            self.shard_indices.added_vocab_end_index)
+        # Get the embeddings.
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
+        output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        output = output.view(input_.shape[0], -1)
+        return output
+
+    def _forward_origin(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = self._get_masked_input_and_mask(