Similar to #2309, this PR introduces tensor model parallelism for the
Embedding layer in order to reduce memory consumption. It supports both
eager mode and graph mode.
This PR also refactors the per-module tensor parallel configurations added in
#2309, #2167, and #2120, merging all of them into `finegrained_tp_config` in
`additional_config`, including:
`lmhead_tensor_parallel_size`
`oproj_tensor_parallel_size`
`embedding_tensor_parallel_size`
`mlp_tensor_parallel_size`
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
282 lines
12 KiB
Python
282 lines
12 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn.parameter import Parameter
|
|
from vllm.distributed import divide
|
|
from vllm.distributed.parallel_state import get_tp_group
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
|
|
VocabParallelEmbedding, pad_vocab_size)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
from vllm_ascend.distributed.parallel_state import (get_embed_tp_group,
|
|
get_lmhead_tp_group)
|
|
from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable
|
|
|
|
|
|
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
    """
    Register VocabParallelEmbedding as a custom op for Ascend.
    AscendVocabParallelEmbedding support different communication parallel groups
    Added the feature of lmheadTP in pure dp scenario

    The communication group is picked from ``prefix`` at construction time:

    * the lm-head TP group when ``lmhead_tp_enable()`` is true and ``prefix``
      contains ``"head"``;
    * the embedding TP group when ``embedding_tp_enable()`` is true and
      ``prefix`` contains ``"embed_tokens"`` (this also switches ``forward``
      to the all-gather / reduce-scatter path);
    * the regular TP group otherwise.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Set up vocab sharding for the selected group and create weights.

        Args:
            num_embeddings: Vocabulary size including any added tokens.
            embedding_dim: Size of each embedding vector.
            params_dtype: Parameter dtype; falls back to
                ``torch.get_default_dtype()`` when ``None``.
            org_num_embeddings: Original vocabulary size (without added
                tokens); defaults to ``num_embeddings``.
            padding_size: Multiple the vocabulary sizes are padded to.
            quant_config: Optional quantization config used to look up the
                quantization method for this layer.
            prefix: Module name prefix; used to choose the communication
                group and passed to ``quant_config.get_quant_method``.

        Raises:
            NotImplementedError: If this is an embedding layer (i.e. not a
                ``ParallelLMHead``) and the selected quantization method does
                not implement ``embedding``.
        """
        # Only nn.Module.__init__ is run; the parent initializer is bypassed
        # and its setup is redone below using the selected group.
        nn.Module.__init__(self)
        self.forward_type = None
        if lmhead_tp_enable() and "head" in prefix:
            self.comm_group = get_lmhead_tp_group()
        elif embedding_tp_enable() and "embed_tokens" in prefix:
            self.comm_group = get_embed_tp_group()
            self.forward_type = "embed_tp"
        else:
            self.comm_group = get_tp_group()

        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group

        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size,
                                               self.tp_rank, self.tp_size)
        self.embedding_dim = embedding_dim
        quant_method = None
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if quant_method is None:
            quant_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        #
        # NOTE: comparing ``type(self)`` with ``VocabParallelEmbedding`` is
        # always false in this class because instances are subclasses, which
        # silently disabled the guard. Anything that is not an LM head is
        # treated as an embedding layer instead.
        is_embedding_layer = not isinstance(self, ParallelLMHead)
        quant_method_implements_embedding = method_has_implemented_embedding(
            type(quant_method))
        if is_embedding_layer and not quant_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(quant_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.quant_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.quant_method.create_weights(self,
                                         self.embedding_dim,
                                         [self.num_embeddings_per_partition],
                                         self.embedding_dim,
                                         self.num_embeddings_padded,
                                         params_dtype=params_dtype,
                                         weight_loader=self.weight_loader)

    def _get_masked_input_and_mask(
            self, input_: torch.Tensor, org_vocab_start_index: int,
            org_vocab_end_index: int, num_org_vocab_padding: int,
            added_vocab_start_index: int,
            added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map global token ids to shard-local row indices.

        Ids inside ``[org_vocab_start_index, org_vocab_end_index)`` or
        ``[added_vocab_start_index, added_vocab_end_index)`` are shifted to
        local indices; every other id is mapped to ``0``.

        Returns:
            A tuple ``(local_ids, invalid_mask)`` where ``invalid_mask`` is
            ``True`` for ids that do not belong to this shard.
        """
        # torch.compile will fuse all of the pointwise ops below
        # into a single kernel, making it very fast
        org_vocab_mask = (input_ >= org_vocab_start_index) & (
            input_ < org_vocab_end_index)
        # Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
        if added_vocab_start_index == added_vocab_end_index:
            # Empty added-vocab range on this shard: only the original
            # vocabulary range can match.
            valid_offset = (org_vocab_start_index * org_vocab_mask)
            vocab_mask = org_vocab_mask
        else:
            added_vocab_mask = (input_ >= added_vocab_start_index) & (
                input_ < added_vocab_end_index)
            # Added-vocab rows are stored after the (padded) original rows.
            added_offset = added_vocab_start_index - (
                org_vocab_end_index -
                org_vocab_start_index) - num_org_vocab_padding
            valid_offset = (org_vocab_start_index *
                            org_vocab_mask) + (added_offset * added_vocab_mask)
            vocab_mask = org_vocab_mask | added_vocab_mask
        # Adapt end.
        input_ = vocab_mask * (input_ - valid_offset)
        return input_, ~vocab_mask

    def forward(self, input_):
        """Dispatch to the embedding-TP path or the default path."""
        if self.forward_type == "embed_tp":
            return self._forward_embed_tp(input_)
        else:
            return self._forward_origin(input_)

    def _forward_embed_tp(self, input_):
        """Embedding lookup using the dedicated embedding TP group.

        The ids are all-gathered along dim 0 across ``comm_group``, looked up
        in the local shard, rows for ids outside the shard are zeroed, and the
        result is reduce-scattered along dim 0 and viewed as
        ``(input_.shape[0], -1)``.
        """
        # NOTE(review): presumably input_ is a 1-D tensor of token ids, so
        # dim 0 is the token dimension - confirm against callers.
        complete_input = self.comm_group.all_gather(input_, dim=0)
        masked_input, input_mask = self._get_masked_input_and_mask(
            complete_input, self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index)
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Zero the rows of ids that are not part of this shard.
        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        output = self.comm_group.reduce_scatter(output_parallel, dim=0)
        output = output.view(input_.shape[0], -1)
        return output

    def _forward_origin(self, input_):
        """Default embedding lookup over the group chosen in ``__init__``.

        Masking is only applied when ``tp_size > 1``; the final reduction is
        delegated to ``torch.ops.vllm.maybe_pad_and_reduce``.
        """
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = self._get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.quant_method.embedding(self,
                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
        return output
|
|
|
|
|
|
class AscendParallelLMHead(ParallelLMHead):
    """
    Register ParallelLMHead as a custom op for Ascend.

    Initialization is delegated to ``AscendVocabParallelEmbedding.__init__``;
    this class only adds the optional ``bias`` parameter on top of it.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the LM head and, if requested, its per-partition bias.

        Args:
            num_embeddings: Vocabulary size including any added tokens.
            embedding_dim: Size of each embedding vector.
            bias: When true, allocate a bias of length
                ``num_embeddings_per_partition``; otherwise register ``bias``
                as ``None``.
            params_dtype: Dtype forwarded to the embedding initializer and
                used for the bias tensor.
            org_num_embeddings: Original vocabulary size, if different.
            padding_size: Multiple the vocabulary sizes are padded to.
            quant_config: Optional quantization config, also stored on the
                instance as ``quant_config``.
            prefix: Module name prefix forwarded to the embedding initializer.
        """
        # Called unbound on purpose: the Ascend embedding initializer is
        # reused here instead of the parent class initializer.
        AscendVocabParallelEmbedding.__init__(
            self,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            params_dtype=params_dtype,
            org_num_embeddings=org_num_embeddings,
            padding_size=padding_size,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.quant_config = quant_config

        if not bias:
            self.register_parameter("bias", None)
            return

        bias_shape = self.num_embeddings_per_partition
        self.bias = Parameter(torch.empty(bias_shape, dtype=params_dtype))
        bias_attrs = {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        }
        set_weight_attrs(self.bias, bias_attrs)
|
|
|
|
|
|
class AscendLogitsProcessor(LogitsProcessor):
    """
    Register LogitsProcessor as a custom op for Ascend.
    Added the feature of lmheadTP in pure dp scenario
    """

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits via the lm-head TP path when enabled, else the normal path."""
        compute = (self._get_logits_lmheadtp
                   if lmhead_tp_enable() else self._get_logits_normal)
        return compute(hidden_states, lm_head, embedding_bias)

    def _get_logits_lmheadtp(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Logits computation over the dedicated lm-head TP group.

        Hidden states are all-gathered along dim 0 across the lm-head TP
        group, projected by ``lm_head``'s quantization method, exchanged with
        ``all_to_all``, and finally cut down to ``org_vocab_size`` columns.
        """
        lmhead_group = get_lmhead_tp_group()
        # Collect the hidden states of every rank in the lm-head TP group.
        all_hidden_states = lmhead_group.all_gather(hidden_states, dim=0)
        partial_logits = lm_head.quant_method.apply(lm_head,
                                                    all_hidden_states,
                                                    bias=embedding_bias)
        # Exchange the per-rank logits across the group.
        logits = lmhead_group.all_to_all(partial_logits)
        if logits is None:
            return None
        # Drop vocabulary padding columns, if any.
        return logits[..., :self.org_vocab_size]

    def _get_logits_normal(
        self,
        hidden_states: torch.Tensor,
        lm_head: AscendParallelLMHead,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Logits computation using the inherited ``_gather_logits`` helper.

        Returns ``None`` when ``_gather_logits`` returns ``None``; otherwise
        the logits cut down to ``org_vocab_size`` columns.
        """
        partial_logits = lm_head.quant_method.apply(lm_head,
                                                    hidden_states,
                                                    bias=embedding_bias)
        # Combine the per-rank logits of the tensor parallel group.
        logits = self._gather_logits(partial_logits)
        if logits is None:
            return None
        # Drop vocabulary padding columns, if any.
        return logits[..., :self.org_vocab_size]
|