[Feat] Add custom Embedding tensor model parallel (#2616)
Similar to #2309, this PR introduces Embedding tensor model parallelism to reduce memory consumption. It supports both eager mode and graph mode.
This PR also refactors the per-module tensor parallel configurations introduced in #2309, #2167, and #2120, merging them all into a single `finegrained_tp_config` entry in `additional_config` (see the usage sketch after the list), with the following keys:
`lmhead_tensor_parallel_size`
`oproj_tensor_parallel_size`
`embedding_tensor_parallel_size`
`mlp_tensor_parallel_size`
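For illustration, a minimal offline-inference sketch of how these keys are grouped under `finegrained_tp_config` (the model name and parallel sizes are placeholders, and passing `additional_config` to `LLM()` follows common vllm-ascend usage; treat the exact call shape as an assumption rather than part of this PR):

```python
# Hypothetical usage sketch: enable embedding / LM-head tensor parallelism
# through additional_config. Key names follow this PR's description.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    tensor_parallel_size=4,
    additional_config={
        "finegrained_tp_config": {
            "embedding_tensor_parallel_size": 2,
            "lmhead_tensor_parallel_size": 2,
            # "oproj_tensor_parallel_size": 2,
            # "mlp_tensor_parallel_size": 2,
        },
    },
)
```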
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
@@ -30,8 +30,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs
 
-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
-from vllm_ascend.utils import lmhead_tp_enable
+from vllm_ascend.distributed.parallel_state import (get_embed_tp_group,
+                                                    get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable
 
 
 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
@@ -50,9 +51,12 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         nn.Module.__init__(self)
 
-        if lmhead_tp_enable() and prefix.find("head") != -1:
+        self.forward_type = None
+        if lmhead_tp_enable() and "head" in prefix:
             self.comm_group = get_lmhead_tp_group()
+        elif embedding_tp_enable() and "embed_tokens" in prefix:
+            self.comm_group = get_embed_tp_group()
+            self.forward_type = "embed_tp"
         else:
             self.comm_group = get_tp_group()
 
@@ -146,6 +150,28 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
         return input_, ~vocab_mask
 
+    def forward(self, input_):
+        if self.forward_type == "embed_tp":
+            return self._forward_embed_tp(input_)
+        else:
+            return self._forward_origin(input_)
+
+    def _forward_embed_tp(self, input_):
+        complete_input = self.comm_group.all_gather(input_, dim=0)
+        masked_input, input_mask = self._get_masked_input_and_mask(
+            complete_input, self.shard_indices.org_vocab_start_index,
+            self.shard_indices.org_vocab_end_index,
+            self.shard_indices.num_org_vocab_padding,
+            self.shard_indices.added_vocab_start_index,
+            self.shard_indices.added_vocab_end_index)
+        # Get the embeddings.
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
+        output = self.comm_group.reduce_scatter(output_parallel, dim=0)
+        output = output.view(input_.shape[0], -1)
+        return output
+
+    def _forward_origin(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = self._get_masked_input_and_mask(
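For readers unfamiliar with the communication pattern in `_forward_embed_tp` above, here is a minimal, self-contained sketch of the same gather / shard-lookup / reduce-scatter flow using plain `torch.distributed` collectives. The function name, shapes, and group handling are illustrative assumptions, not the vllm-ascend API:

```python
# Illustrative sketch only: each rank holds a vocab shard (local_weight)
# and a slice of the token ids; names and shapes are assumptions.
import torch
import torch.distributed as dist


def embed_tp_forward(token_ids: torch.Tensor, local_weight: torch.Tensor,
                     vocab_start: int, vocab_end: int,
                     group: dist.ProcessGroup) -> torch.Tensor:
    world = dist.get_world_size(group)

    # 1) All-gather token ids along dim 0 so every rank sees the full batch.
    gathered = [torch.empty_like(token_ids) for _ in range(world)]
    dist.all_gather(gathered, token_ids, group=group)
    complete = torch.cat(gathered, dim=0)

    # 2) Mask ids outside this rank's vocab shard and shift the rest into
    #    the local index range.
    mask = (complete < vocab_start) | (complete >= vocab_end)
    local_ids = (complete - vocab_start).masked_fill(mask, 0)

    # 3) Look up the local shard and zero the masked rows; after the sum in
    #    step 4, exactly one rank contributes each embedding row.
    out = local_weight[local_ids]
    out = out.masked_fill(mask.unsqueeze(-1), 0)

    # 4) Reduce-scatter: sum the partial embeddings across ranks and keep
    #    only this rank's slice of the batch (mirrors reduce_scatter(dim=0)).
    result = torch.empty(token_ids.shape[0], local_weight.shape[1],
                         dtype=out.dtype, device=out.device)
    dist.reduce_scatter_tensor(result, out, group=group)
    return result
```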