[Feat] flashcomm_v2 optim solution (#3232)

### What this PR does / why we need it?
Supports generalized FlashComm2 optimization, which reduces
communication overhead, decreases RmsNorm computation, and saves one
AllGather step by replacing Allreduce operations in the Attention module
with pre-AlltoAll and post-AllGather operations (used in combination
with FlashComm1). This feature is enabled during the Prefill phase and
is recommended to be used together with FlashComm1, delivering broad
performance improvements, especially in long sequence scenarios with
large tensor parallelism (TP) configurations. Benchmark tests show that
under TP16DP1 configuration, it can improve the prefill performance of
the DeepSeek model by 8% on top of FlashComm1.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
This commit is contained in:
Levi
2025-11-10 11:01:45 +08:00
committed by GitHub
parent b1a00e0512
commit 0a62e671fb
12 changed files with 380 additions and 24 deletions

View File

@@ -24,6 +24,7 @@ CustomLinearOp
└── CustomRowParallelOp
│ ├── MLPRowParallelOp
│ ├── OProjRowParallelOp
| ├── Flashcomm2OProjRowParallelOp
│ ├── MatmulAllreduceRowParallelOp
│ └── SequenceRowParallelOp
└── CustomReplicatedOp
@@ -41,6 +42,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (split_tensor_along_last_dim,
@@ -49,9 +51,14 @@ from vllm.distributed import (split_tensor_along_last_dim,
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
@@ -263,6 +270,135 @@ class OProjRowParallelOp(CustomRowParallelOp):
self.input_size_per_partition = self.layer.input_size_per_partition
class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
def __init__(self, layer):
super().__init__(layer)
self.odp_group = get_flashcomm2_odp_group()
self.odp_size = self.odp_group.world_size
self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
get_tp_group().world_size)
self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
self.layer._quant_comm_config = {}
@property
def comm_group(self):
return get_flashcomm2_otp_group()
@property
def tp_rank(self):
if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
return 0
return self.comm_group.rank_in_group
@property
def tp_size(self):
if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
return 1
return self.comm_group.world_size
def apply_impl(
self,
input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer for Flashcomm2.
Input.ahspe = [batchsize*seqlength, headnum*headdim/TP]
Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
"""
# Handle input parallelism - split or use as-is
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = self.tp_rank
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# padding for all-to-all
forward_context = get_forward_context()
num_padding_tokens = forward_context.pad_size
if num_padding_tokens > 0:
input_parallel = nn.functional.pad(input_parallel,
(0, 0, 0, num_padding_tokens))
def otp_maybe_quant_comm(x):
# Reorganize the tensor so that the batch id and rank id correspond to each other.
chunk_num = len(self.reorgnized_batch_ids) * len(
self.reorgnized_batch_ids[0])
batch_size = x.size(0)
assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"
batch_size_per_chunk = batch_size // chunk_num
# Indices of reorganized tensor
chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
reorganized_chunks = chunked[self.group_indices]
send_buf = reorganized_chunks.flatten(1, 2)
# all-to-all operation parameters
all2all_tp_size = self.odp_size
local_intermediate_size = x.size(1)
chunk_size = x.size(0) // all2all_tp_size
total_intermediate_size = local_intermediate_size * all2all_tp_size
# Create receive buffer
recv_buf = torch.empty(total_intermediate_size * chunk_size,
dtype=x.dtype,
device=x.device)
# Perform all-to-all communication
dist.all_to_all_single(recv_buf,
send_buf,
group=self.odp_group.device_group)
return recv_buf.view(all2all_tp_size, chunk_size,
-1).transpose(0, 1).reshape(chunk_size, -1)
if not hasattr(self, "_quant_comm_config"):
self.layer._quant_comm_config = {}
self.layer._quant_comm_config[
"communication_fn"] = otp_maybe_quant_comm
actual_quant_method = getattr(self.quant_method, 'quant_method',
self.quant_method)
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
# Check if w8a8 quantization is enabled. If not, communicate immediately.
input_parallel = otp_maybe_quant_comm(input_parallel)
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self.layer,
input_parallel,
bias=bias_)
# output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
if self.tp_size > 1:
# flashcomm2 with reduce-scatter
output = self.comm_group.reduce_scatter(output_parallel, dim=0)
else:
output = output_parallel
if not forward_context.sp_enabled:
# flashcomm1 not enabled
output = get_tp_group().all_gather(output, 0)
if num_padding_tokens > 0:
output = output[:-num_padding_tokens]
# Handle bias return based on configuration
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def update_attrs(self):
super().update_attrs()
self.input_is_parallel = self.layer.input_is_parallel
self.input_size_per_partition = self.layer.input_size_per_partition
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
_HCOMM_INFO = None
@@ -487,13 +623,17 @@ def _get_column_parallel_op(
def _get_row_parallel_op(
prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]]:
if "down_proj" in prefix and mlp_tp_enable():
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
return OProjRowParallelOp(layer)
if matmul_allreduce_enable():
return MatmulAllreduceRowParallelOp(layer)
if flashcomm2_enable():
if "o_proj" in prefix or "out_proj" in prefix:
return Flashcomm2OProjRowParallelOp(layer)
if enable_sp():
if "shared_expert" in prefix:
return None
@@ -509,6 +649,7 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
return None, 0, 1
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]] = None
if direct == "row":