[Feature] Optimize SP & add SP support for Qwen3 Next. (#3225)

This PR accomplishes the following tasks:
**Optimize SP**
In the old implementation, sequence parallelism performed a full all_reduce and relied on the RMSNorm at the first layer to split the result into chunks. We now perform a reduce_scatter on the embedding side instead, replacing one all_reduce plus one chunk with a single reduce_scatter. Since the embedding output is then already sequence-sharded, every layer's column-parallel op can all-gather its input unconditionally instead of special-casing layer 0.
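
For context, the saving comes from a standard collective identity: an all_reduce followed by keeping only this rank's chunk produces the same shard as a single reduce_scatter, without materializing the full reduced tensor on every rank. A minimal torch.distributed sketch (not code from this PR; assumes an initialized process group and a sequence length divisible by the TP world size):

```python
# Sketch only: why "all_reduce + take this rank's chunk" equals one reduce_scatter.
# Assumes torch.distributed is initialized (e.g. via torchrun) and the sequence
# dimension divides evenly by the TP world size.
import torch
import torch.distributed as dist


def old_path(hidden: torch.Tensor) -> torch.Tensor:
    # Old flow: materialize the full all-reduced tensor on every rank, then keep 1/tp of it.
    dist.all_reduce(hidden)
    rank, world = dist.get_rank(), dist.get_world_size()
    return hidden.chunk(world, dim=0)[rank]


def new_path(hidden: torch.Tensor) -> torch.Tensor:
    # New flow: reduce_scatter produces the same per-rank shard directly,
    # without every rank holding the full reduced tensor afterwards.
    world = dist.get_world_size()
    shard = torch.empty_like(hidden.chunk(world, dim=0)[0])
    dist.reduce_scatter_tensor(shard, hidden)
    return shard
```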
**Support Qwen3 Next**
Qwen3 Next includes a linear attention module whose projection prefixes (e.g. in_proj, out_proj, conv1d) do not match the existing prefix checks, so the op dispatch is extended to recognize them.
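
For context, the ops in this file are selected by substring-matching the layer's prefix. Qwen3 Next's linear attention block exposes projections such as in_proj, out_proj and conv1d rather than qkv_proj/o_proj, which is why those substrings are added to the match lists in this PR. A rough illustration of the matching (the example prefixes are hypothetical):

```python
# Illustration only: prefix-substring dispatch, with hypothetical Qwen3 Next prefixes.
def pick_column_op(prefix: str) -> str:
    # Mirrors the checks added in this PR: linear-attention projections ("in_proj",
    # "conv1d") are routed the same way as regular attention/MLP projections.
    if "gate_up_proj" in prefix or "in_proj" in prefix:
        return "SequenceColumnParallelOp"
    if "qkv_proj" in prefix or "conv1d" in prefix:
        return "SequenceColumnParallelOp"
    return "default tensor-parallel linear"


print(pick_column_op("model.layers.3.self_attn.qkv_proj"))   # SequenceColumnParallelOp
print(pick_column_op("model.layers.7.linear_attn.in_proj"))  # SequenceColumnParallelOp
print(pick_column_op("model.layers.7.linear_attn.conv1d"))   # SequenceColumnParallelOp
```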


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Author: weijinqian0
Date: 2025-10-13 23:02:12 +08:00
Committed by: GitHub
Parent: 31682961af
Commit: 6972df5951
10 changed files with 140 additions and 193 deletions


@@ -20,13 +20,12 @@ Current class inheritance structure:
CustomTensorParallelOp
├── CustomColumnParallelOp
│   ├── MLPColumnParallelOp
-│   ├── DenseOptimMergedColumnParallelOp
-│   └── DenseOptimQKVParallelOp
+│   ├── SequenceColumnParallelOp
└── CustomRowParallelOp
    ├── MLPRowParallelOp
    ├── OProjRowParallelOp
    ├── MatmulAllreduceRowParallelOp
-    └── DenseOptimRowParallelOp
+    └── SequenceRowParallelOp

How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
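
A hedged sketch of that extension recipe (steps 2 and 3 of the docstring are cut off by this hunk): MyColumnParallelOp and the "my_proj" prefix are made-up names, the method body is heavily simplified, and the snippet is assumed to live in this module next to CustomColumnParallelOp and the new _get_column_parallel_op added later in this diff:

```python
# Hedged sketch of the extension recipe; names and body are illustrative only.
import torch


class MyColumnParallelOp(CustomColumnParallelOp):  # 1. inherit from the base class

    def apply_impl(self, input_: torch.Tensor):
        # 2. implement the custom compute/communication path.
        bias = self.bias if not self.skip_bias_add else None
        output = self.quant_method.apply(self.layer, input_, bias)
        return output, None


def _get_column_parallel_op(prefix, layer):
    # 3. add a branch for the new class in _get_column_parallel_op
    #    (shown standalone here for brevity).
    if "my_proj" in prefix:  # hypothetical prefix
        return MyColumnParallelOp(layer)
    return None
```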
@@ -36,7 +35,7 @@ How to extend a new linear op? Taking column parallel op as an example:
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
"""
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
import torch
import torch.distributed as dist
@@ -153,69 +152,6 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
        return output, output_bias


-class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
-
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        """Linear layer with column parallelism.
-        Implemented multiple optimization projects for dense models, such as FlashComm and
-        communication-computation fusion.
-        """
-        bias = self.bias if not self.skip_bias_add else None
-
-        # Matrix multiply.
-        assert self.quant_method is not None
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
-        output_parallel = self.quant_method.apply(self.layer, input_, bias)
-        if self.gather_output:
-            # All-gather across the partitions.
-            output = self.comm_group.all_gather(output_parallel)
-        else:
-            output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
-        return output, output_bias
-
-
-class SequenceQKVParallelOp(CustomColumnParallelOp):
-
-    def __init__(self, layer, prefix):
-        super().__init__(layer)
-        self.prefix = prefix
-
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        """Linear layer with column parallelism.
-        Implemented multiple optimization projects for dense models, such as FlashComm and
-        communication-computation fusion.
-        """
-        bias = self.bias if not self.skip_bias_add else None
-
-        # Matrix multiply.
-        assert self.quant_method is not None
-        layer_num = self.prefix.split('.')[2]
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            input_, layer_num != '0')
-        output_parallel = self.quant_method.apply(self.layer, input_, bias)
-        if self.gather_output:
-            # All-gather across the partitions.
-            output = self.comm_group.all_gather(output_parallel)
-        else:
-            output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
-        return output, output_bias


class MLPRowParallelOp(CustomRowParallelOp):

    def __init__(self, layer):
@@ -364,11 +300,35 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
        self.weight_t = self.layer.weight.t()


-class SequenceRowParallelOp(CustomRowParallelOp):
+class SequenceColumnParallelOp(CustomColumnParallelOp):

    def __init__(self, layer, prefix):
        super().__init__(layer)
        self.prefix = prefix

+    def apply_impl(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        """Linear layer with column parallelism.
+        Implemented multiple optimization projects for dense models, such as FlashComm and
+        communication-computation fusion.
+        """
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self.layer, input_, bias)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = self.comm_group.all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+
+class SequenceRowParallelOp(CustomRowParallelOp):

    def apply_impl(
        self, input_: torch.Tensor
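
The two Sequence* ops implement the usual sequence-parallel sandwich: the column op all-gathers the sequence-sharded activations right before its matmul (the maybe_all_gather_and_maybe_unpad call above), and the row side returns the summed output to per-rank sequence shards. A generic sketch of that data flow, under the assumption that the row op uses a reduce_scatter-style collective (the exact collective used by SequenceRowParallelOp is not shown in this hunk):

```python
# Generic sequence-parallel sketch (not the exact ops in this file): gather the
# sequence shards before the column matmul, return to shards after the row matmul.
import torch
import torch.distributed as dist


def column_parallel_forward(x_shard: torch.Tensor, w_col: torch.Tensor) -> torch.Tensor:
    # x_shard: [seq/tp, hidden]. Gather the full sequence first, then multiply
    # by the column-sharded weight.
    world = dist.get_world_size()
    gathered = [torch.empty_like(x_shard) for _ in range(world)]
    dist.all_gather(gathered, x_shard)
    return torch.cat(gathered, dim=0) @ w_col


def row_parallel_forward(y_partial: torch.Tensor, w_row: torch.Tensor) -> torch.Tensor:
    # Partial products are summed across TP ranks while re-sharding the sequence,
    # so every rank ends up holding its own [seq/tp, hidden] slice again.
    out_full = y_partial @ w_row
    world = dist.get_world_size()
    shard = torch.empty_like(out_full.chunk(world, dim=0)[0])
    dist.reduce_scatter_tensor(shard, out_full)
    return shard
```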
@@ -408,50 +368,55 @@ class SequenceRowParallelOp(CustomRowParallelOp):
        self.reduce_results = self.layer.reduce_results


-def get_column_parallel_op(
-    disable_tp, prefix, layer
-) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-                          SequenceQKVParallelOp]], int, int]:
+def _get_column_parallel_op(
+    prefix, layer
+) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
+    if mlp_tp_enable() and "gate_up_proj" in prefix:
+        return MLPColumnParallelOp(layer)
+    if enable_sp():
+        if "shared_expert" in prefix:
+            return None
+        if "gate_up_proj" in prefix:
+            return SequenceColumnParallelOp(layer)
+        if "in_proj" in prefix:
+            return SequenceColumnParallelOp(layer)
+        if "qkv_proj" in prefix or "conv1d" in prefix:
+            return SequenceColumnParallelOp(layer)
+    return None
+
+
+def _get_row_parallel_op(
+    prefix, layer
+) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+                    MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
+    if "down_proj" in prefix and mlp_tp_enable():
+        return MLPRowParallelOp(layer)
+    if "o_proj" in prefix and oproj_tp_enable():
+        return OProjRowParallelOp(layer)
+    if matmul_allreduce_enable():
+        return MatmulAllreduceRowParallelOp(layer)
+    if enable_sp():
+        if "shared_expert" in prefix:
+            return None
+        if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix:
+            return SequenceRowParallelOp(layer)
+    return None
+
+
+def get_parallel_op(disable_tp, prefix, layer, direct):
    if disable_tp:
        return None, 0, 1

-    custom_op: Optional[Union[
-        MLPColumnParallelOp,
-        SequenceMergedColumnParallelOp,
-        SequenceQKVParallelOp,
-    ]] = None
-    if "gate_up_proj" in prefix and mlp_tp_enable():
-        custom_op = MLPColumnParallelOp(layer)
-    elif "gate_up_proj" in prefix and enable_sp():
-        custom_op = SequenceMergedColumnParallelOp(layer)
-    elif enable_sp():
-        custom_op = SequenceQKVParallelOp(layer, prefix)
-
-    if custom_op is not None:
-        return custom_op, custom_op.tp_rank, custom_op.tp_size
-    return None, get_tp_group().rank_in_group, get_tp_group().world_size
-
-
-def get_row_parallel_op(
-    disable_tp, prefix, layer
-) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
-                          MatmulAllreduceRowParallelOp,
-                          SequenceRowParallelOp]], int, int]:
-    if disable_tp:
-        return None, 0, 1
-
-    custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+    custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
+                              MLPRowParallelOp, OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
                              SequenceRowParallelOp]] = None
-    if "down_proj" in prefix and mlp_tp_enable():
-        custom_op = MLPRowParallelOp(layer)
-    elif "o_proj" in prefix and oproj_tp_enable():
-        custom_op = OProjRowParallelOp(layer)
-    elif matmul_allreduce_enable():
-        custom_op = MatmulAllreduceRowParallelOp(layer)
-    elif enable_sp():
-        custom_op = SequenceRowParallelOp(layer, prefix)
+    if direct == "row":
+        custom_op = _get_row_parallel_op(prefix, layer)
+    if direct == "column":
+        custom_op = _get_column_parallel_op(prefix, layer)

    if custom_op is not None:
        return custom_op, custom_op.tp_rank, custom_op.tp_size
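
For reference, a hypothetical call site for the new single entry point; the prefix, the layer object and the surrounding wiring are illustrative assumptions, not taken from the rest of the repo:

```python
# Hypothetical call site for the unified dispatcher; values are illustrative only.
custom_op, tp_rank, tp_size = get_parallel_op(
    disable_tp=False,
    prefix="model.layers.0.linear_attn.in_proj",
    layer=linear_layer,   # a vLLM parallel linear layer instance
    direct="column",
)
# With enable_sp() active this prefix hits the "in_proj" branch and returns a
# SequenceColumnParallelOp; if nothing matches, the caller falls back to plain
# tensor parallelism with tp_rank/tp_size taken from get_tp_group().
```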