[feat]: oproj tensor parallelism in pure DP and graph-mode scenarios. (#2167)
### What this PR does / why we need it?
This PR introduces Oproj matrix tensor model parallel to achieve
decreasing of memory consumption. It only support graph mode in pure DP
scenario.
In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
oproj_tensor_parallel_size = 8, we have 1 ms TPOT increasing, saved 5.8
GB NPU memory per RANK. We got best performance when
oproj_tensor_parallel_size=4 without TPOT increasing.
performance data:
<img width="1442" height="442" alt="image"
src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d"
/>
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |
example
`--additional_config={"oproj_tensor_parallel_size": 8}`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
eddaafc1c7
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
This commit is contained in:
@@ -16,7 +16,8 @@
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_lora # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_lora_linear # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
"""
|
||||
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
This file is a part of the vllm-ascend project.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
split_tensor_along_last_dim)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
_HCOMM_INFO = None
|
||||
|
||||
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""
|
||||
AscendRowParallelLinear is a custom implementation of RowParallelLinear
|
||||
that overrides the forward method to handle Ascend-specific operations.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
tp_group = get_tp_group().device_group
|
||||
hcomm_info = self.get_hcomm_info(tp_group)
|
||||
self.hcomm_info = hcomm_info
|
||||
super().__init__(*args, **kwargs)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
@staticmethod
|
||||
def get_hcomm_info(group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group.
|
||||
|
||||
Args:
|
||||
group (ProcessGroup): The process group for which to get the HCCL communication info.
|
||||
|
||||
Returns:
|
||||
str: The HCCL communication name for the given group.
|
||||
"""
|
||||
global _HCOMM_INFO
|
||||
if _HCOMM_INFO is not None:
|
||||
return _HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
_HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
|
||||
else:
|
||||
_HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return _HCOMM_INFO
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Forward pass for the AscendRowParallelLinear layer.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to the layer.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
The output tensor after applying the linear transformation,
|
||||
and optionally the bias if `return_bias` is True.
|
||||
"""
|
||||
input_parallel = self.calc_input(input_)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
output = self.calc_output(input_parallel)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the input tensor for parallel processing.
|
||||
|
||||
Args:
|
||||
input_ (torch.Tensor): the input tensor to be processed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The input tensor split along the last dimension
|
||||
for tensor model parallelism, or the original input if not parallel.
|
||||
"""
|
||||
if self.input_is_parallel:
|
||||
return input_
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
return splitted_input[tp_rank].contiguous()
|
||||
|
||||
def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation.
|
||||
|
||||
Args:
|
||||
input_parallel (_type_): the input tensor to be processed in parallel.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the output tensor after applying the linear transformation
|
||||
and optionally handle communication between tensor model parallel ranks.
|
||||
"""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
return output
|
||||
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
|
||||
logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
|
||||
vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear
|
||||
15
vllm_ascend/patch/worker/patch_common/patch_lora.py
Normal file
15
vllm_ascend/patch/worker/patch_common/patch_lora.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import vllm
|
||||
from vllm.lora.utils import _all_lora_classes
|
||||
|
||||
from vllm_ascend.patch.worker.patch_common.patch_lora_embedding import \
|
||||
AscendVocabParallelEmbeddingWithLoRA
|
||||
from vllm_ascend.patch.worker.patch_common.patch_lora_linear import (
|
||||
AscendColumnParallelLinearWithLoRA,
|
||||
AscendMergedColumnParallelLinearWithLoRA, AscendRowParallelLinearWithLoRA)
|
||||
|
||||
_all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
||||
_all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
||||
_all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
|
||||
_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
|
||||
vllm.lora.utils._all_lora_classes = _all_lora_classes
|
||||
@@ -1,11 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
|
||||
from vllm.lora.utils import _all_lora_classes
|
||||
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
@@ -22,8 +20,3 @@ class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515)
|
||||
_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes = _all_lora_classes
|
||||
|
||||
52
vllm_ascend/patch/worker/patch_common/patch_lora_linear.py
Normal file
52
vllm_ascend/patch/worker/patch_common/patch_lora_linear.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA)
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
|
||||
|
||||
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendRowParallelLinear
|
||||
|
||||
|
||||
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendColumnParallelLinear
|
||||
|
||||
|
||||
class AscendMergedColumnParallelLinearWithLoRA(
|
||||
MergedColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendMergedColumnParallelLinear
|
||||
Reference in New Issue
Block a user