[feat]: oproj tensor parallelism in pure DP and graph-mode scenarios. (#2167)
### What this PR does / why we need it?
This PR introduces Oproj matrix tensor model parallel to achieve
decreasing of memory consumption. It only support graph mode in pure DP
scenario.
In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
oproj_tensor_parallel_size = 8, we have 1 ms TPOT increasing, saved 5.8
GB NPU memory per RANK. We got best performance when
oproj_tensor_parallel_size=4 without TPOT increasing.
performance data:
<img width="1442" height="442" alt="image"
src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d"
/>
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |
example
`--additional_config={"oproj_tensor_parallel_size": 8}`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
eddaafc1c7
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
This commit is contained in:
@@ -18,27 +18,33 @@ limitations under the License.
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import divide, split_tensor_along_last_dim
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.lora.utils import LinearBase
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
QuantizeMethodBase,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
get_mlp_tensor_model_parallel_rank,
|
||||
get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.utils import (matmul_allreduce_enable, mlp_tp_enable,
|
||||
oproj_tp_enable)
|
||||
|
||||
_HCOMM_INFO = None
|
||||
|
||||
|
||||
class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
@@ -59,15 +65,15 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
if prefix.find("gate_up_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
self.comm_group = None
|
||||
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
|
||||
self.comm_group = get_mlp_tp_group()
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
self.comm_group = get_tp_group()
|
||||
|
||||
self.tp_size = self.comm_group.world_size
|
||||
self.tp_rank = self.comm_group.rank_in_group
|
||||
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
@@ -77,14 +83,14 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
LinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
@@ -114,7 +120,7 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
|
||||
class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
class AscendRowParallelLinear(RowParallelLinear):
|
||||
"""Linear layer with row parallelism.
|
||||
Use the MLP tensor parallelism group in the MLP module,
|
||||
and the original TP group in other modules.
|
||||
@@ -134,27 +140,37 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
if prefix.find("down_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
if prefix.find("down_proj") != -1 and mlp_tp_enable():
|
||||
comm_group = get_mlp_tp_group()
|
||||
self.forward_type = "mlp_tp"
|
||||
elif prefix.find("o_proj") != -1 and oproj_tp_enable():
|
||||
comm_group = get_otp_group()
|
||||
self.forward_type = "oproj_tp"
|
||||
elif matmul_allreduce_enable():
|
||||
comm_group = get_tp_group()
|
||||
self.forward_type = "matmul_allreduce"
|
||||
self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
comm_group = get_tp_group()
|
||||
self.forward_type = "normal"
|
||||
self.comm_group = comm_group
|
||||
|
||||
self.tp_size = self.comm_group.world_size
|
||||
self.tp_rank = self.comm_group.rank_in_group
|
||||
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
LinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
@@ -184,61 +200,140 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if matmul_allreduce_enable():
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
@staticmethod
|
||||
def get_hcomm_info(group: ProcessGroup) -> str:
|
||||
"""Get the HCCL communication information for the given group."""
|
||||
global _HCOMM_INFO
|
||||
if _HCOMM_INFO is not None:
|
||||
return _HCOMM_INFO
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
_HCOMM_INFO = group._get_backend(
|
||||
torch.device("npu")).get_hccl_comm_name(global_rank)
|
||||
else:
|
||||
_HCOMM_INFO = group.get_hccl_comm_name(rank)
|
||||
return _HCOMM_INFO
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
is_prefill: bool = True,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.enable_mlp_optimze:
|
||||
tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
|
||||
# output = output[:num_tokens,:]
|
||||
# dispose_tensor(output_parallel)
|
||||
# Choose different forward function according to the type of TP group
|
||||
if self.forward_type == "oproj_tp":
|
||||
return self._forward_oproj_tp(input_)
|
||||
elif self.forward_type == "mlp_tp":
|
||||
return self._forward_mlp_tp(input_)
|
||||
elif self.forward_type == "matmul_allreduce":
|
||||
return self._forward_matmul_allreduce(input_)
|
||||
else:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
return super().forward(input_)
|
||||
|
||||
# enable custom MLP tensor parallel
|
||||
def _forward_mlp_tp(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
assert self.quant_method is not None
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = self.comm_group.reduce_scatter(output_parallel, 0)
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in TP>1 case)
|
||||
bias_ = None if (self.tp_rank > 0
|
||||
or self.skip_bias_add) else self.bias
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
# enable custom Oproj tensor parallel
|
||||
def _forward_oproj_tp(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Prepare tensors for all-to-all communication
|
||||
local_batch_size = input_parallel.size(0)
|
||||
chunk_size = self.input_size_per_partition
|
||||
total_batch_size = local_batch_size * self.tp_size
|
||||
|
||||
# Reshape tensor for efficient cross-device transfer:
|
||||
# [batch, dim] -> [tp_size, batch, chunk] -> flattened
|
||||
send_buf = (input_parallel.reshape(-1,
|
||||
self.tp_size, chunk_size).transpose(
|
||||
0, 1).contiguous().view(-1))
|
||||
|
||||
# Create receive buffer
|
||||
recv_buf = torch.empty(total_batch_size * chunk_size,
|
||||
dtype=input_parallel.dtype,
|
||||
device=input_parallel.device)
|
||||
|
||||
# Perform all-to-all communication
|
||||
dist.all_to_all_single(recv_buf,
|
||||
send_buf,
|
||||
group=self.comm_group.device_group)
|
||||
input_parallel = recv_buf.view(total_batch_size, chunk_size)
|
||||
|
||||
# Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
|
||||
# otp-specific: Combine partial results across devices
|
||||
output = self.comm_group.reduce_scatter(output_parallel, dim=0)
|
||||
|
||||
# Handle bias return based on configuration
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def _forward_matmul_allreduce(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
"""Calculate the output tensor of forward by considering
|
||||
fusing communication and computation."""
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output = torch_npu.npu_mm_all_reduce_base(input_parallel,
|
||||
self.weight_t,
|
||||
self.hcomm_info,
|
||||
bias=bias_)
|
||||
else:
|
||||
output = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
@@ -262,48 +357,85 @@ class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
if prefix.find("gate_up_proj") != -1:
|
||||
self.tp_size = get_mlp_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_mlp_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = True
|
||||
self.comm_group = None
|
||||
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
|
||||
self.comm_group = get_mlp_tp_group()
|
||||
self.forward_type = "mlp_tp"
|
||||
else:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.enable_mlp_optimze = False
|
||||
self.comm_group = get_tp_group()
|
||||
self.forward_type = "normal_tp"
|
||||
self.tp_rank = self.comm_group.rank_in_group
|
||||
self.tp_size = self.comm_group.world_size
|
||||
|
||||
self.output_sizes = output_sizes
|
||||
assert all(output_size % self.tp_size == 0
|
||||
for output_size in output_sizes)
|
||||
AscendMlpColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
AscendColumnParallelLinear.__init__(self,
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.forward_type == "mlp_tp":
|
||||
return self._forward_mlp_tp(input_)
|
||||
else:
|
||||
return super().forward(input_)
|
||||
|
||||
def _forward_mlp_tp(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
# self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
if self.enable_mlp_optimze:
|
||||
input2_ = get_mlp_tp_group().all_gather(input_, 0)
|
||||
output = self.quant_method.apply(self, input2_, bias)
|
||||
else:
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
input_parallel = get_mlp_tp_group().all_gather(input_, 0)
|
||||
output = self.quant_method.apply(self, input_parallel, bias)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class AscendLinearBase(LinearBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[
|
||||
QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
self.disable_tp = disable_tp
|
||||
|
||||
Reference in New Issue
Block a user