[Refactor] [SP]The sequence parallelism characteristics in the MoE and Dense models are integrated into a single solution. (#3085)

What this PR does / why we need it?

there are two sets of sp implementations for moe and dense models. One
is called sequence_parallelism, and the other is flashcomm_v1.
We did the following things:

Merge two sets of code with the same implementation into one.
Remove the implementation of sequence_parallelism, as this solution
cannot support aclgraph.
Does this PR introduce any user-facing change?

No

How was this patch tested?

e2e&ut

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-09-24 11:29:59 +08:00
committed by GitHub
parent e7618d9414
commit 6aa4253798
14 changed files with 90 additions and 215 deletions

View File

@@ -48,8 +48,9 @@ from vllm.distributed.parallel_state import get_tp_group
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
mlp_tp_enable, oproj_tp_enable)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
class CustomTensorParallelOp:
@@ -82,10 +83,17 @@ class CustomTensorParallelOp:
self.skip_bias_add = self.layer.skip_bias_add
self.return_bias = self.layer.return_bias
self.quant_method = self.layer.quant_method
self.prefix = self.layer.prefix
def apply_impl(self, input_):
raise NotImplementedError
# Replace layer.forward to customize the layer computation process.
def apply(self, input_):
raise NotImplementedError
output, output_bias = self.apply_impl(input_)
if not self.return_bias:
return output
return output, output_bias
class CustomColumnParallelOp(CustomTensorParallelOp):
@@ -113,6 +121,14 @@ class CustomRowParallelOp(CustomTensorParallelOp):
self.reduce_results = self.layer.reduce_results
self.input_size_per_partition = self.layer.input_size_per_partition
def apply(self, input_):
output, output_bias = self.apply_impl(input_)
if dense_optim_enable():
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
if not self.return_bias:
return output
return output, output_bias
class MLPColumnParallelOp(CustomColumnParallelOp):
@@ -123,7 +139,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
def comm_group(self):
return get_mlp_tp_group()
def apply(
def apply_impl(
self,
input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -134,14 +150,12 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
output = self.quant_method.apply(self.layer, input_parallel, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
def apply(
def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.
@@ -164,18 +178,16 @@ class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class DenseOptimQKVParallelOp(CustomColumnParallelOp):
class SequenceQKVParallelOp(CustomColumnParallelOp):
def __init__(self, layer, prefix):
super().__init__(layer)
self.prefix = prefix
def apply(
def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.
@@ -201,8 +213,6 @@ class DenseOptimQKVParallelOp(CustomColumnParallelOp):
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
@@ -215,7 +225,7 @@ class MLPRowParallelOp(CustomRowParallelOp):
def comm_group(self):
return get_mlp_tp_group()
def apply(
def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
@@ -234,8 +244,6 @@ class MLPRowParallelOp(CustomRowParallelOp):
output = self.comm_group.reduce_scatter(output_parallel, 0)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
@@ -248,7 +256,7 @@ class OProjRowParallelOp(CustomRowParallelOp):
def comm_group(self):
return get_otp_group()
def apply(
def apply_impl(
self,
input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -294,8 +302,6 @@ class OProjRowParallelOp(CustomRowParallelOp):
# Handle bias return based on configuration
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def update_attrs(self):
@@ -311,7 +317,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
super().__init__(layer)
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
def apply(
def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
@@ -335,8 +341,6 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
bias=bias_)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
@classmethod
@@ -359,13 +363,13 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
self.weight_t = self.layer.weight.t()
class DenseOptimRowParallelOp(CustomRowParallelOp):
class SequenceRowParallelOp(CustomRowParallelOp):
def __init__(self, layer, prefix):
super().__init__(layer)
self.prefix = prefix
def apply(
def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.
@@ -391,12 +395,8 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
input_parallel,
bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def update_attrs(self):
@@ -407,23 +407,22 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
def get_column_parallel_op(
disable_tp, prefix, layer
) -> Tuple[
Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp,
DenseOptimQKVParallelOp]], int, int]:
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
SequenceQKVParallelOp]], int, int]:
if disable_tp:
return None, 0, 1
custom_op: Optional[Union[
MLPColumnParallelOp,
DenseOptimMergedColumnParallelOp,
DenseOptimQKVParallelOp,
SequenceMergedColumnParallelOp,
SequenceQKVParallelOp,
]] = None
if "gate_up_proj" in prefix and mlp_tp_enable():
custom_op = MLPColumnParallelOp(layer)
elif "gate_up_proj" in prefix and dense_optim_enable():
custom_op = DenseOptimMergedColumnParallelOp(layer)
elif dense_optim_enable():
custom_op = DenseOptimQKVParallelOp(layer, prefix)
elif "gate_up_proj" in prefix and enable_sp():
custom_op = SequenceMergedColumnParallelOp(layer)
elif enable_sp():
custom_op = SequenceQKVParallelOp(layer, prefix)
if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size
@@ -435,21 +434,21 @@ def get_row_parallel_op(
disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
DenseOptimRowParallelOp]], int, int]:
SequenceRowParallelOp]], int, int]:
if disable_tp:
return None, 0, 1
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
DenseOptimRowParallelOp]] = None
SequenceRowParallelOp]] = None
if "down_proj" in prefix and mlp_tp_enable():
custom_op = MLPRowParallelOp(layer)
elif "o_proj" in prefix and oproj_tp_enable():
custom_op = OProjRowParallelOp(layer)
elif matmul_allreduce_enable():
custom_op = MatmulAllreduceRowParallelOp(layer)
elif dense_optim_enable():
custom_op = DenseOptimRowParallelOp(layer, prefix)
elif enable_sp():
custom_op = SequenceRowParallelOp(layer, prefix)
if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size