[Feature] optimize sp & qwen3 next support sp. (#3225)
This PR will accomplish the following tasks: **optimize SP** In the old version implementation, the first layer was all_reduce, which used rms to split chunks. We changed it to perform reduce_scatter on the embedding side, replace one all_reduce operation and one chunk with one reduce_scatter operation. **Support qwen3 next** Since Qwen3 Next includes a linear attention module, the prefix name of this module cannot take effect directly. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -64,7 +64,6 @@ class AscendRMSNorm(RMSNorm):
|
||||
import torch_npu
|
||||
|
||||
if residual is not None:
|
||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, residual = _addrmsnorm_forward_oot(
|
||||
self, x, residual, self.next_need_quant_fusion_linear)
|
||||
|
||||
@@ -34,8 +34,7 @@ from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ops.linear_op import (get_column_parallel_op,
|
||||
get_row_parallel_op)
|
||||
from vllm_ascend.ops.linear_op import get_parallel_op
|
||||
|
||||
|
||||
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
|
||||
@@ -100,8 +99,8 @@ class AscendQKVParallelLinear(QKVParallelLinear):
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, _, tp_size = get_column_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self,
|
||||
"column")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
@@ -173,8 +172,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "column")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.output_sizes = output_sizes
|
||||
assert all(output_size % self.tp_size == 0
|
||||
@@ -222,8 +221,8 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "row")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
@@ -304,8 +303,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
|
||||
disable_tp, prefix, self)
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "column")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
|
||||
@@ -20,13 +20,12 @@ Current class inheritance structure:
|
||||
CustomTensorParallelOp
|
||||
├── CustomColumnParallelOp
|
||||
│ ├── MLPColumnParallelOp
|
||||
│ ├── DenseOptimMergedColumnParallelOp
|
||||
│ └── DenseOptimQKVParallelOp
|
||||
│ ├── SequenceColumnParallelOp
|
||||
└── CustomRowParallelOp
|
||||
├── MLPRowParallelOp
|
||||
├── OProjRowParallelOp
|
||||
├── MatmulAllreduceRowParallelOp
|
||||
└── DenseOptimRowParallelOp
|
||||
└── SequenceRowParallelOp
|
||||
|
||||
How to extend a new linear op? Taking column parallel op as an example:
|
||||
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
|
||||
@@ -36,7 +35,7 @@ How to extend a new linear op? Taking column parallel op as an example:
|
||||
Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -153,69 +152,6 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = self.comm_group.all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class SequenceQKVParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def __init__(self, layer, prefix):
|
||||
super().__init__(layer)
|
||||
self.prefix = prefix
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
layer_num = self.prefix.split('.')[2]
|
||||
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
input_, layer_num != '0')
|
||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = self.comm_group.all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MLPRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
@@ -364,11 +300,35 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
||||
self.weight_t = self.layer.weight.t()
|
||||
|
||||
|
||||
class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
class SequenceColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def __init__(self, layer, prefix):
|
||||
super().__init__(layer)
|
||||
self.prefix = prefix
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||
communication-computation fusion.
|
||||
"""
|
||||
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
|
||||
output_parallel = self.quant_method.apply(self.layer, input_, bias)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = self.comm_group.all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
def apply_impl(
|
||||
self, input_: torch.Tensor
|
||||
@@ -408,50 +368,55 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
self.reduce_results = self.layer.reduce_results
|
||||
|
||||
|
||||
def get_column_parallel_op(
|
||||
disable_tp, prefix, layer
|
||||
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
|
||||
SequenceQKVParallelOp]], int, int]:
|
||||
def _get_column_parallel_op(
|
||||
prefix, layer
|
||||
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
|
||||
if mlp_tp_enable() and "gate_up_proj" in prefix:
|
||||
return MLPColumnParallelOp(layer)
|
||||
if enable_sp():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
if "gate_up_proj" in prefix:
|
||||
return SequenceColumnParallelOp(layer)
|
||||
if "in_proj" in prefix:
|
||||
return SequenceColumnParallelOp(layer)
|
||||
if "qkv_proj" in prefix or "conv1d" in prefix:
|
||||
return SequenceColumnParallelOp(layer)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_row_parallel_op(
|
||||
prefix, layer
|
||||
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
|
||||
if "down_proj" in prefix and mlp_tp_enable():
|
||||
return MLPRowParallelOp(layer)
|
||||
if "o_proj" in prefix and oproj_tp_enable():
|
||||
return OProjRowParallelOp(layer)
|
||||
if matmul_allreduce_enable():
|
||||
return MatmulAllreduceRowParallelOp(layer)
|
||||
if enable_sp():
|
||||
if "shared_expert" in prefix:
|
||||
return None
|
||||
if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix:
|
||||
return SequenceRowParallelOp(layer)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_parallel_op(disable_tp, prefix, layer, direct):
|
||||
if disable_tp:
|
||||
return None, 0, 1
|
||||
|
||||
custom_op: Optional[Union[
|
||||
MLPColumnParallelOp,
|
||||
SequenceMergedColumnParallelOp,
|
||||
SequenceQKVParallelOp,
|
||||
]] = None
|
||||
if "gate_up_proj" in prefix and mlp_tp_enable():
|
||||
custom_op = MLPColumnParallelOp(layer)
|
||||
elif "gate_up_proj" in prefix and enable_sp():
|
||||
custom_op = SequenceMergedColumnParallelOp(layer)
|
||||
elif enable_sp():
|
||||
custom_op = SequenceQKVParallelOp(layer, prefix)
|
||||
|
||||
if custom_op is not None:
|
||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||
|
||||
return None, get_tp_group().rank_in_group, get_tp_group().world_size
|
||||
|
||||
|
||||
def get_row_parallel_op(
|
||||
disable_tp, prefix, layer
|
||||
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]], int, int]:
|
||||
if disable_tp:
|
||||
return None, 0, 1
|
||||
|
||||
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
|
||||
MLPRowParallelOp, OProjRowParallelOp,
|
||||
MatmulAllreduceRowParallelOp,
|
||||
SequenceRowParallelOp]] = None
|
||||
if "down_proj" in prefix and mlp_tp_enable():
|
||||
custom_op = MLPRowParallelOp(layer)
|
||||
elif "o_proj" in prefix and oproj_tp_enable():
|
||||
custom_op = OProjRowParallelOp(layer)
|
||||
elif matmul_allreduce_enable():
|
||||
custom_op = MatmulAllreduceRowParallelOp(layer)
|
||||
elif enable_sp():
|
||||
custom_op = SequenceRowParallelOp(layer, prefix)
|
||||
if direct == "row":
|
||||
custom_op = _get_row_parallel_op(prefix, layer)
|
||||
|
||||
if direct == "column":
|
||||
custom_op = _get_column_parallel_op(prefix, layer)
|
||||
|
||||
if custom_op is not None:
|
||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -15,27 +13,6 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
|
||||
def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
residual: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return residual
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
assert sp_enabled is True, ("Currently, this situation only occurs "
|
||||
"when sp is enabled")
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
|
||||
|
||||
return residual
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||
label: bool) -> torch.Tensor:
|
||||
try:
|
||||
@@ -187,12 +164,6 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: residual,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=lambda x, label: x,
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import divide, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed import divide
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@@ -163,7 +163,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user