[Feat] flashcomm_v2 optim solution (#3232)
### What this PR does / why we need it?
Supports the generalized FlashComm2 optimization: the AllReduce operations in the Attention module are replaced with a pre-AlltoAll and a post-AllGather (used in combination with FlashComm1), which reduces communication overhead, removes redundant RMSNorm computation, and saves one AllGather step. The feature takes effect during the prefill phase and is recommended together with FlashComm1, delivering broad performance improvements, especially in long-sequence scenarios with large tensor parallelism (TP) configurations. Benchmarks show that under a TP16DP1 configuration it improves the prefill performance of the DeepSeek model by 8% on top of FlashComm1.
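For intuition, here is a minimal single-process sketch (an illustration written for this description, not code from the PR) of why redistributing tokens with AlltoAll before the o_proj matmul and gathering afterwards is numerically equivalent to the usual row-parallel matmul followed by AllReduce:

```python
# Toy equivalence check: a row-parallel linear normally computes
# Y = AllReduce_tp(X_tp @ W_tp). FlashComm2 instead lets each rank
# multiply only its 1/TP share of tokens and restores the batch at the
# end; collectives are simulated here with plain indexing and sums.
import torch

tp, tokens, head_dim, hidden = 4, 8, 16, 32
x = torch.randn(tokens, head_dim * tp)   # full attention output
w = torch.randn(head_dim * tp, hidden)   # full o_proj weight

# Baseline: each "rank" holds a column slice of x and a row slice of w;
# AllReduce is simulated by summing the partial products.
x_shards = x.chunk(tp, dim=1)
w_shards = w.chunk(tp, dim=0)
y_allreduce = sum(xs @ ws for xs, ws in zip(x_shards, w_shards))

# FlashComm2-style: rank r keeps its weight slice but only the r-th
# tokens/TP slab of every input shard (simulated AlltoAll), computes the
# partial product for its token slab, and the slabs are summed and
# concatenated back (simulated reduce-scatter + AllGather).
y_flashcomm2 = torch.cat([
    sum(xs.chunk(tp, dim=0)[r] @ ws for xs, ws in zip(x_shards, w_shards))
    for r in range(tp)
], dim=0)

torch.testing.assert_close(y_allreduce, y_flashcomm2)
```

The saving comes from moving activations while they are still sharded: each rank exchanges `1/TP`-wide slices instead of all-reducing full `hidden_size` outputs, roughly halving communication volume relative to a ring AllReduce.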
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
@@ -24,6 +24,7 @@ CustomLinearOp
└── CustomRowParallelOp
│ ├── MLPRowParallelOp
│ ├── OProjRowParallelOp
│ ├── Flashcomm2OProjRowParallelOp
│ ├── MatmulAllreduceRowParallelOp
│ └── SequenceRowParallelOp
└── CustomReplicatedOp
@@ -41,6 +42,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (split_tensor_along_last_dim,
@@ -49,9 +51,14 @@ from vllm.distributed import (split_tensor_along_last_dim,
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                    get_flashcomm2_otp_group,
                                                    get_mlp_tp_group,
                                                    get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                               flashcomm2_enable,
                               get_flashcomm2_reorgnized_batch_ids,
                               matmul_allreduce_enable, mlp_tp_enable,
                               oproj_tp_enable, shared_expert_dp_enabled)
@@ -263,6 +270,135 @@ class OProjRowParallelOp(CustomRowParallelOp):
        self.input_size_per_partition = self.layer.input_size_per_partition


class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
        super().__init__(layer)
        self.odp_group = get_flashcomm2_odp_group()
        self.odp_size = self.odp_group.world_size
        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
            get_tp_group().world_size)
        self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
        self.layer._quant_comm_config = {}

    @property
    def comm_group(self):
        return get_flashcomm2_otp_group()

    @property
    def tp_rank(self):
        if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
            return 0
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
            return 1
        return self.comm_group.world_size

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Linear layer for Flashcomm2.

        Input.shape = [batchsize*seqlength, headnum*headdim/TP]
        Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
        """
        # Handle input parallelism - split or use as-is
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = self.tp_rank
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Padding for all-to-all
        forward_context = get_forward_context()
        num_padding_tokens = forward_context.pad_size
        if num_padding_tokens > 0:
            input_parallel = nn.functional.pad(input_parallel,
                                               (0, 0, 0, num_padding_tokens))

        def otp_maybe_quant_comm(x):
            # Reorganize the tensor so that batch ids and rank ids
            # correspond to each other.
            chunk_num = len(self.reorgnized_batch_ids) * len(
                self.reorgnized_batch_ids[0])
            batch_size = x.size(0)

            assert batch_size % chunk_num == 0, f"batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"

            batch_size_per_chunk = batch_size // chunk_num
            # Indices of the reorganized tensor
            chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
            reorganized_chunks = chunked[self.group_indices]
            send_buf = reorganized_chunks.flatten(1, 2)

            # All-to-all operation parameters
            all2all_tp_size = self.odp_size
            local_intermediate_size = x.size(1)
            chunk_size = x.size(0) // all2all_tp_size
            total_intermediate_size = local_intermediate_size * all2all_tp_size

            # Create receive buffer
            recv_buf = torch.empty(total_intermediate_size * chunk_size,
                                   dtype=x.dtype,
                                   device=x.device)

            # Perform all-to-all communication
            dist.all_to_all_single(recv_buf,
                                   send_buf,
                                   group=self.odp_group.device_group)

            return recv_buf.view(all2all_tp_size, chunk_size,
                                 -1).transpose(0, 1).reshape(chunk_size, -1)

        if not hasattr(self.layer, "_quant_comm_config"):
            self.layer._quant_comm_config = {}
        self.layer._quant_comm_config[
            "communication_fn"] = otp_maybe_quant_comm
        actual_quant_method = getattr(self.quant_method, 'quant_method',
                                      self.quant_method)
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
        if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
            # w8a8 quantization is not enabled: communicate immediately.
            # (The w8a8 path instead invokes communication_fn after
            # quantizing the activations.)
            input_parallel = otp_maybe_quant_comm(input_parallel)

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias

        output_parallel = self.quant_method.apply(self.layer,
                                                  input_parallel,
                                                  bias=bias_)
        # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
        if self.tp_size > 1:
            # flashcomm2 with reduce-scatter
            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
        else:
            output = output_parallel

        if not forward_context.sp_enabled:
            # flashcomm1 not enabled
            output = get_tp_group().all_gather(output, 0)
            if num_padding_tokens > 0:
                output = output[:-num_padding_tokens]

        # Handle bias return based on configuration
        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def update_attrs(self):
        super().update_attrs()
        self.input_is_parallel = self.layer.input_is_parallel
        self.input_size_per_partition = self.layer.input_size_per_partition


class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
    _HCOMM_INFO = None
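The trickiest part of the class above is the chunk bookkeeping inside `otp_maybe_quant_comm`. The following single-process toy (my own sketch; it uses an identity chunk permutation, whereas the real `reorgnized_batch_ids` mapping depends on the TP/odp configuration) simulates `dist.all_to_all_single` with plain indexing and verifies the final layout: after the exchange, each odp rank owns `batch/odp_size` tokens whose hidden slices from all peers sit side by side.

```python
import torch

odp_size, batch, local_hidden = 2, 4, 3
chunk_size = batch // odp_size

# x[r]: activation shard held by odp rank r -- same tokens, a different
# slice of the hidden dimension. Offsets of 100*r make values traceable.
x = [torch.arange(batch * local_hidden, dtype=torch.float32)
     .view(batch, local_hidden) + 100 * r for r in range(odp_size)]

# Simulate dist.all_to_all_single: rank r receives the r-th batch slab of
# every peer's send buffer, concatenated along dim 0 (identity permutation
# here; the PR first reorders chunks via group_indices).
recv = [torch.cat([x[j][r * chunk_size:(r + 1) * chunk_size]
                   for j in range(odp_size)])
        for r in range(odp_size)]

# Fold the receive buffer exactly as the PR does: the odp_size slabs are
# interleaved so each token row carries the full hidden dimension.
out = [buf.view(odp_size, chunk_size, -1).transpose(0, 1)
       .reshape(chunk_size, -1) for buf in recv]

# Rank r now owns tokens [r*chunk_size:(r+1)*chunk_size] with all peers'
# hidden slices concatenated.
full = torch.cat(x, dim=1)
for r in range(odp_size):
    torch.testing.assert_close(out[r],
                               full[r * chunk_size:(r + 1) * chunk_size])
```

Note also the design choice of registering `otp_maybe_quant_comm` under `_quant_comm_config["communication_fn"]` rather than always calling it inline: for `AscendW8A8LinearMethod` the exchange is deferred, apparently so it can run on already-quantized int8 activations and roughly halve the traffic; for all other quant methods it fires immediately.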
@@ -487,13 +623,17 @@ def _get_column_parallel_op(
def _get_row_parallel_op(
    prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                    Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
                    SequenceRowParallelOp]]:
    if "down_proj" in prefix and mlp_tp_enable():
        return MLPRowParallelOp(layer)
    if "o_proj" in prefix and oproj_tp_enable():
        return OProjRowParallelOp(layer)
    if matmul_allreduce_enable():
        return MatmulAllreduceRowParallelOp(layer)
    if flashcomm2_enable():
        if "o_proj" in prefix or "out_proj" in prefix:
            return Flashcomm2OProjRowParallelOp(layer)
    if enable_sp():
        if "shared_expert" in prefix:
            return None
@@ -509,6 +649,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
        return None, 0, 1
    custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                              MLPRowParallelOp, OProjRowParallelOp,
                              Flashcomm2OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
                              SequenceRowParallelOp]] = None
    if direct == "row":
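For completeness, a hypothetical launch sketch. The only knob confirmed by this diff is `flashcomm2_oproj_tensor_parallel_size` (read through `get_ascend_config()`); the `additional_config` plumbing is standard vllm-ascend, but the FlashComm1 environment variable name and the exact condition behind `flashcomm2_enable()` may differ in your version:

```python
# Hypothetical usage sketch -- verify knob names against your vllm-ascend
# release before relying on this.
import os

from vllm import LLM

# FlashComm1 (assumed env var name; the PR recommends combining both).
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # example model from the benchmark
    tensor_parallel_size=16,          # TP16DP1 as in the PR's benchmark
    additional_config={
        # Read by Flashcomm2OProjRowParallelOp; with size 1 the op skips
        # the reduce-scatter (the tp_size == 1 branch above).
        "flashcomm2_oproj_tensor_parallel_size": 1,
    },
)
```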