Files
xc-llm-ascend/vllm_ascend/ops/linear_op.py
wangbj127 2ad0ca52a6 Qwen3.5 MoE supports flashcomm v1 (#7644)
cherry pick from https://github.com/vllm-project/vllm-ascend/pull/7486
<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->
Multimodal models like Qwen3.5 MoE do the embedding in model_runner, so
when flash comm is enabled, the first AllGather operation should be
skipped.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
No.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
- vLLM version: v0.18.0
- vLLM main:
8b6325758c

---------

Signed-off-by: Wangbingjie <wangbj1207@126.com>
Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
2026-03-25 23:09:33 +08:00

752 lines
28 KiB
Python

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file extends the functionality of linear operations by encapsulating custom
communication groups and forward functions into classes (linear ops).
Current class inheritance structure:
CustomLinearOp
├── CustomColumnParallelOp
│   ├── MLPColumnParallelOp
│   ├── SequenceColumnParallelOp
│   ├── Flashcomm2OshardQKVParallelOp
│   └── ShardedCPColumnParallelOp
├── CustomRowParallelOp
│   ├── MLPRowParallelOp
│   ├── OProjRowParallelOp
│   ├── Flashcomm2OProjRowParallelOp
│   ├── MatmulAllreduceRowParallelOp
│   ├── SequenceRowParallelOp
│   └── ShardedCPRowParallelOp
└── CustomReplicatedOp
How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed,
override the comm_group method
3. Override the apply_impl method according to requirements; apply() wraps it and replaces the original
linear.forward
4. Add selection logic for MyColumnParallelOp in the _get_column_parallel_op function, typically based on
prefix and configuration judgments
Row parallel op follows a similar approach - inherit from CustomRowParallelOp and register the new class in
_get_row_parallel_op.
"""
from functools import lru_cache
from types import SimpleNamespace
import regex as re
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.models.utils import extract_layer_index
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import (
get_flashcomm2_odp_group,
get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group,
)
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
enable_sp,
flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
get_weight_prefetch_method,
is_vl_model,
matmul_allreduce_enable,
mlp_tp_enable,
oproj_tp_enable,
shared_expert_dp_enabled,
)
class CustomLinearOp:
    """Base class for ops that take over the forward computation of a linear layer.

    An op keeps a reference to the wrapped layer, mirrors the layer attributes
    it needs at run time (see ``update_attrs``) and exposes ``apply`` as the
    replacement for the layer's forward. Subclasses implement ``apply_impl``.
    """

    def __init__(self, layer):
        self.layer = layer
        # Mirrors of layer attributes; populated by update_attrs().
        self.bias = None
        self.skip_bias_add = None
        self.return_bias = None
        self.quant_method = None
        # Declared up front like the other mirrors so that reading it before
        # update_attrs() yields None instead of raising AttributeError.
        self.prefix = None

    # Custom communication group, while determining weight sharding
    @property
    def comm_group(self):
        """Communication group used by this op; the TP group unless overridden."""
        return get_tp_group()

    @property
    def tp_rank(self):
        """Rank of the current process inside ``comm_group``."""
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        """Number of ranks in ``comm_group``."""
        return self.comm_group.world_size

    # Update the attributes required by apply(), obtaining them from the layer.
    # Call this after the layer completes its initialization, specifically at the end of layer.init().
    def update_attrs(self):
        """Copy the attributes needed by ``apply`` from the wrapped layer."""
        # ``bias`` is the only attribute treated as optional on the layer;
        # when it is absent the None default from __init__ is kept.
        if hasattr(self.layer, "bias"):
            self.bias = self.layer.bias
        self.skip_bias_add = self.layer.skip_bias_add
        self.return_bias = self.layer.return_bias
        self.quant_method = self.layer.quant_method
        self.prefix = self.layer.prefix

    def apply_impl(self, input_):
        """Compute and return ``(output, output_bias)``; provided by subclasses."""
        raise NotImplementedError

    # Replace layer.forward to customize the layer computation process.
    def apply(self, input_):
        """Run ``apply_impl`` and return the bias too only when ``return_bias`` is set."""
        output, output_bias = self.apply_impl(input_)
        if not self.return_bias:
            return output
        return output, output_bias
class CustomColumnParallelOp(CustomLinearOp):
    """Base class for custom ops attached to column-parallel linear layers."""

    def __init__(self, layer):
        super().__init__(layer)
        # Mirror of layer.gather_output; populated by update_attrs().
        self.gather_output = None

    def update_attrs(self):
        """Refresh the base attributes, then mirror ``gather_output`` from the layer."""
        super().update_attrs()
        wrapped = self.layer
        self.gather_output = wrapped.gather_output
class CustomRowParallelOp(CustomLinearOp):
    """Base class for custom ops attached to row-parallel linear layers."""

    def __init__(self, layer):
        super().__init__(layer)
        # Mirrors of layer attributes; populated by update_attrs().
        self.reduce_results = None
        self.input_is_parallel = None
        self.input_size_per_partition = None

    def update_attrs(self):
        """Refresh the base attributes, then mirror the row-parallel ones."""
        super().update_attrs()
        wrapped = self.layer
        self.input_is_parallel = wrapped.input_is_parallel
        self.reduce_results = wrapped.reduce_results
        self.input_size_per_partition = wrapped.input_size_per_partition

    def apply(self, input_):
        """Run ``apply_impl``, notify the weight-prefetch hook, and shape the return value.

        Returns ``output`` alone unless ``return_bias`` is set, in which case
        ``(output, output_bias)`` is returned.
        """
        output, output_bias = self.apply_impl(input_)
        # The prefetch hook is handed this op's output together with its prefix.
        prefetcher = get_weight_prefetch_method()
        prefetcher.maybe_prefetch_mlp_weight_preprocess(prefetcher.MLP_GATE_UP, output, self.prefix)
        return (output, output_bias) if self.return_bias else output

    def get_input_parallel(self, input_: torch.Tensor) -> torch.Tensor:
        """Return this rank's slice of ``input_`` along the last dimension.

        When the layer already receives a partitioned input it is returned
        unchanged.
        """
        if not self.input_is_parallel:
            shards = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
            return shards[self.tp_rank].contiguous()
        return input_
class CustomReplicatedOp(CustomLinearOp):
    """Custom op for replicated linear layers; no communication is performed here."""

    def apply_impl(self, input_):
        """Apply the quant method and return ``(output, output_bias)``.

        The bias is either fused into the matmul or, when ``skip_bias_add`` is
        set, handed back to the caller unapplied.
        """
        assert self.quant_method is not None
        if self.skip_bias_add:
            fused_bias, output_bias = None, self.bias
        else:
            fused_bias, output_bias = self.bias, None
        output = self.quant_method.apply(self.layer, input_, fused_bias)
        return output, output_bias
class MLPColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op that communicates over the dedicated MLP TP group."""

    @property
    def comm_group(self):
        """MLP tensor-parallel group, used instead of the default TP group."""
        return get_mlp_tp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """All-gather the input over the MLP TP group along dim 0, then apply the quant method.

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        assert self.quant_method is not None
        fused_bias = None if self.skip_bias_add else self.bias
        gathered = self.comm_group.all_gather(input_, 0)
        # Matrix multiply.
        output = self.quant_method.apply(self.layer, gathered, fused_bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class MLPRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that communicates over the dedicated MLP TP group."""

    @property
    def comm_group(self):
        """MLP tensor-parallel group, used instead of the default TP group."""
        return get_mlp_tp_group()

    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the quant method on the local shard and reduce-scatter the result along dim 0.

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        local_input = self.get_input_parallel(input_)
        assert self.quant_method is not None
        # The bias is fused on rank 0 only, so the reduction adds it once.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.layer.bias
        partial = self.quant_method.apply(self.layer, local_input, bias=fused_bias)
        output = self.comm_group.reduce_scatter(partial, 0)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class OProjRowParallelOp(CustomRowParallelOp):
    """Row-parallel op for the attention output projection using the OTP group."""

    @property
    def comm_group(self):
        """Output-projection tensor-parallel (OTP) group."""
        return get_otp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """All-to-all the input over the OTP group, run the matmul, then reduce-scatter.

        Returns ``(output, output_bias)``; ``output`` is viewed as
        ``[input_.shape[0], layer.output_size]``.
        """
        local_input = self.get_input_parallel(input_)
        group_size = self.tp_size

        # Sizes of the all-to-all exchange.
        chunk_width = self.input_size_per_partition
        gathered_rows = local_input.size(0) * group_size

        # Reshape tensor for efficient cross-device transfer:
        # [batch, dim] -> [tp_size, batch, chunk] -> flattened
        per_rank_chunks = local_input.reshape(-1, group_size, chunk_width).transpose(0, 1)
        send_buf = per_rank_chunks.contiguous().view(-1)
        recv_buf = torch.empty(
            gathered_rows * chunk_width,
            dtype=local_input.dtype,
            device=local_input.device,
        )
        dist.all_to_all_single(recv_buf, send_buf, group=self.comm_group.device_group)
        exchanged = recv_buf.view(gathered_rows, chunk_width)

        # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias
        assert self.quant_method is not None
        partial = self.quant_method.apply(self.layer, exchanged, bias=fused_bias)

        # otp-specific: Combine partial results across devices
        output = self.comm_group.reduce_scatter(partial, dim=0)
        output = output.view(input_.shape[0], self.layer.output_size)

        # Handle bias return based on configuration
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def update_attrs(self):
        """Refresh mirrored attributes from the layer."""
        super().update_attrs()
        wrapped = self.layer
        self.input_is_parallel = wrapped.input_is_parallel
        self.input_size_per_partition = wrapped.input_size_per_partition
class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
    """Row-parallel output projection used when FlashComm2 is enabled.

    The input is exchanged with an all-to-all over the ODP group, multiplied
    by the local weight shard, and, when the OTP size is larger than one,
    reduce-scattered over the OTP group.
    """

    def __init__(self, layer):
        super().__init__(layer)
        self.odp_group = get_flashcomm2_odp_group()
        self.odp_size = self.odp_group.world_size
        self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
        self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(get_tp_group().world_size)
        self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
        # Holds the communication callback that apply_impl publishes on the layer.
        self.layer._quant_comm_config = {}

    @property
    def comm_group(self):
        """FlashComm2 OTP group."""
        return get_flashcomm2_otp_group()

    @property
    def tp_rank(self):
        """Rank inside the OTP group, or 0 when the configured OTP size is 1."""
        if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
            return 0
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        """World size of the OTP group, or 1 when the configured OTP size is 1."""
        if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
            return 1
        return self.comm_group.world_size

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Linear layer for Flashcomm2.

        Input.shape = [batchsize*seqlength, headnum*headdim/TP]
        Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        # Handle input parallelism - split or use as-is
        input_parallel = self.get_input_parallel(input_)

        # padding for all-to-all
        num_padding_tokens = _EXTRA_CTX.pad_size
        if num_padding_tokens > 0:
            input_parallel = nn.functional.pad(input_parallel, (0, 0, 0, num_padding_tokens))

        def otp_maybe_quant_comm(x):
            # Reorganize the tensor so that the batch id and rank id correspond to each other.
            chunk_num = len(self.reorgnized_batch_ids) * len(self.reorgnized_batch_ids[0])
            batch_size = x.size(0)

            assert batch_size % chunk_num == 0, f"Batch_size({batch_size}) must be divisible by chunk_num({chunk_num})"

            batch_size_per_chunk = batch_size // chunk_num
            # Indices of reorganized tensor
            chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
            if self.otp_size != 1:
                chunked = chunked[self.group_indices]
            send_buf = chunked.flatten(1, 2)

            # all-to-all operation parameters
            all2all_tp_size = self.odp_size
            local_intermediate_size = x.size(1)
            chunk_size = x.size(0) // all2all_tp_size
            total_intermediate_size = local_intermediate_size * all2all_tp_size

            # Create receive buffer
            recv_buf = torch.empty(total_intermediate_size * chunk_size, dtype=x.dtype, device=x.device)

            # Perform all-to-all communication
            dist.all_to_all_single(recv_buf, send_buf, group=self.odp_group.device_group)

            return recv_buf.view(all2all_tp_size, chunk_size, -1).transpose(0, 1).reshape(chunk_size, -1)

        # The config dict lives on the layer, not on this op, so the existence
        # check must target the layer. Checking ``self`` was always true and
        # replaced the dict on every forward call.
        if not hasattr(self.layer, "_quant_comm_config"):
            self.layer._quant_comm_config = {}
        self.layer._quant_comm_config["communication_fn"] = otp_maybe_quant_comm
        actual_quant_method = getattr(self.quant_method, "quant_method", self.quant_method)
        from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod

        if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
            # Check if w8a8 quantization is enabled. If not, communicate immediately.
            # NOTE(review): for W8A8 the callback stored above is presumably
            # invoked by the quant method itself - confirm against its apply().
            input_parallel = otp_maybe_quant_comm(input_parallel)

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias

        output_parallel = self.quant_method.apply(self.layer, input_parallel, bias=bias_)
        # output_parallel shape: [bs/(TP/flashcomm2_otp_size), hiddenstate]
        if self.tp_size > 1:
            # flashcomm2 with reduce-scatter
            output = self.comm_group.reduce_scatter(output_parallel, dim=0)
        else:
            output = output_parallel

        if not _EXTRA_CTX.flash_comm_v1_enabled:
            # flashcomm1 not enabled
            output = get_tp_group().all_gather(output, 0)
            if num_padding_tokens > 0:
                output = output[:-num_padding_tokens]

        # Handle bias return based on configuration
        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def update_attrs(self):
        """Refresh mirrored attributes and register the layer with the OShard manager if enabled."""
        super().update_attrs()
        self.input_is_parallel = self.layer.input_is_parallel
        self.input_size_per_partition = self.layer.input_size_per_partition
        if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
            flashcomm2_oshard_manager.register_layer(self.layer, prefetch_step=1)
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that uses the fused NPU matmul + all-reduce kernel."""

    # HCCL communicator name, resolved on first use and shared by all instances.
    _HCOMM_INFO = None

    def __init__(self, layer):
        super().__init__(layer)
        self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)

    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Calculate the output tensor of forward by considering
        fusing communication and computation.

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        input_parallel = self.get_input_parallel(input_)
        # The bias is fused on rank 0 only, so the reduction adds it once.
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if self.reduce_results and self.tp_size > 1:
            output = torch_npu.npu_mm_all_reduce_base(
                input_parallel, self.layer.weight.t(), self.hcomm_info, bias=bias_
            )
        else:
            # Nothing to reduce: fall back to the layer's quant method.
            assert self.quant_method is not None
            output = self.quant_method.apply(self.layer, input_parallel, bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    @classmethod
    def get_hcomm_info(cls, group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group.

        The name is cached on the class after the first lookup, so later calls
        return the cached value regardless of the ``group`` argument.
        """
        if cls._HCOMM_INFO is not None:
            return cls._HCOMM_INFO

        rank = torch.distributed.get_rank(group)
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            cls._HCOMM_INFO = group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank)
        else:
            cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
        return cls._HCOMM_INFO
class SequenceColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op selected when sequence parallelism is enabled."""

    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Linear layer with column parallelism.

        Implemented multiple optimization projects for dense models, such as FlashComm and
        communication-computation fusion.

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        assert self.quant_method is not None
        fused_bias = None if self.skip_bias_add else self.bias

        # The all-gather is skipped only for an attention projection of layer 0
        # in a VL model; every other case requests it.
        need_all_gather = (
            extract_layer_index(self.layer.prefix) != 0
            or not is_vl_model()
            or "attn" not in self.prefix
        )
        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, label=need_all_gather)

        # Matrix multiply.
        output = self.quant_method.apply(self.layer, input_, fused_bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = self.comm_group.all_gather(output)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
    """Column-parallel QKV op used together with FlashComm2 OShard."""

    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Column-parallel linear with FlashComm2 OShard optimization.

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        assert self.quant_method is not None
        fused_bias = None if self.skip_bias_add else self.bias

        if enable_sp():
            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)

        # Trigger async broadcast before matmul to overlap communication.
        flashcomm2_oshard_manager.trigger_broadcast_for_layer(self.layer.prefix)

        # Matrix multiply.
        output = self.quant_method.apply(self.layer, input_, fused_bias)
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = self.comm_group.all_gather(output)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class SequenceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op selected when sequence parallelism is enabled."""

    def __init__(self, layer):
        super().__init__(layer)
        # Mirror of layer.unique_prefix; populated by update_attrs().
        self.unique_prefix = None

    def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Linear layer with row parallelism.

        Implemented multiple optimization projects for dense models, such as FlashComm and
        communication-computation fusion.

        Returns ``(output, output_bias)``; ``output_bias`` is only non-None
        when ``skip_bias_add`` is set.
        """
        local_input = self.get_input_parallel(input_)
        assert self.quant_method is not None
        # The bias is fused on rank 0 only, so a reduction adds it once.
        fused_bias = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias

        needs_reduction = self.tp_size != 1 and self.reduce_results
        if needs_reduction:
            # NOTE(review): this custom op presumably resolves the layer from
            # unique_prefix and ends up in matmul_and_reduce() - confirm.
            output = torch.ops.vllm.matmul_and_reduce(local_input, self.unique_prefix)
        else:
            output = self.quant_method.apply(self.layer, local_input, bias=fused_bias)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Parameter | None) -> torch.Tensor:
        """Multiply the local shard and reduce the partial results across ranks.

        Without FlashComm v1 the result is all-reduced. With it, the result is
        reduce-scattered, using the fused NPU kernel when ``mmrs_fusion`` is
        set and the layer is unquantized or W8A8-quantized.
        """
        assert self.quant_method is not None
        try:
            flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
            mmrs_fusion = _EXTRA_CTX.mmrs_fusion
        except AssertionError:
            # Context not available: behave as if both features are off.
            flash_comm_v1_enabled = mmrs_fusion = False

        layer = self.layer
        if not flash_comm_v1_enabled:
            partial = layer.quant_method.apply(layer, input_parallel, bias=bias_)
            return tensor_model_parallel_all_reduce(partial)

        x = input_parallel
        pad_size = _EXTRA_CTX.pad_size
        # Pad the token dimension, except for o_proj when DSA-CP is enabled.
        if pad_size > 0 and (not enable_dsa_cp() or "o_proj" not in layer.prefix):
            x = F.pad(x, (0, 0, 0, pad_size))

        world_size = layer.tp_size
        comm_mode = "aiv"
        tp_backend = get_tp_group().device_group._get_backend(torch.device("npu"))
        hcom_name = tp_backend.get_hccl_comm_name(layer.tp_rank)

        from vllm.model_executor.layers.linear import UnquantizedLinearMethod

        from vllm_ascend.quantization.method_adapters import AscendLinearMethod
        from vllm_ascend.quantization.methods import AscendW8A8LinearMethod

        # For unquant
        if mmrs_fusion and isinstance(layer.quant_method, UnquantizedLinearMethod):
            output = torch_npu.npu_mm_reduce_scatter_base(
                x,
                layer.weight.t(),
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                comm_mode=comm_mode,
            )
            # The kernel is called without a bias; add it afterwards in place.
            if bias_ is not None:
                output.add_(bias_)
            return output

        # For w8a8 quant
        if (
            mmrs_fusion
            and isinstance(layer.quant_method, AscendLinearMethod)
            and isinstance(layer.quant_method.quant_method, AscendW8A8LinearMethod)
        ):
            if x.dtype == torch.int8:
                x_quant = x
            else:
                x_quant = torch.ops.vllm.quantize(
                    x,
                    layer.aclnn_input_scale,
                    layer.aclnn_input_scale_reciprocal,
                    layer.aclnn_input_offset,
                )
            quant_bias = layer.quant_bias
            deq_scale = layer.deq_scale
            output_dtype = torch.bfloat16
            output = torch_npu.npu_mm_reduce_scatter_base(
                x_quant,
                layer.weight,
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                x2_scale=deq_scale,
                output_dtype=output_dtype,
                comm_mode=comm_mode,
            )
            dequantized_bias = torch.mul(quant_bias, deq_scale).to(layer.params_dtype)
            return torch.add(output, dequantized_bias)

        # No fused kernel applies: plain matmul followed by a reduce-scatter on dim 0.
        partial = layer.quant_method.apply(layer, x, bias=bias_)
        return tensor_model_parallel_reduce_scatter(partial, 0)

    def update_attrs(self):
        """Refresh mirrored attributes, including ``unique_prefix``."""
        super().update_attrs()
        wrapped = self.layer
        self.input_is_parallel = wrapped.input_is_parallel
        self.reduce_results = wrapped.reduce_results
        self.unique_prefix = wrapped.unique_prefix
class ShardedCPRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that runs the matmul locally with no TP communication."""

    @property
    def comm_group(self):
        # fake comm group to bypass tp logic
        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)

    def apply_impl(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the quant method to ``input_`` and return ``(output, output_bias)``.

        The tuple is returned unconditionally: the inherited ``apply`` unpacks
        two values and handles ``return_bias`` itself, so returning a bare
        tensor here would make that unpacking iterate over the tensor.
        """
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        assert self.quant_method is not None
        output = self.quant_method.apply(self.layer, input_, bias_)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def update_attrs(self):
        """Refresh mirrored attributes, then turn off result reduction on the layer."""
        super().update_attrs()
        self.layer.reduce_results = False
class ShardedCPColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op that runs the matmul locally with no TP communication."""

    @property
    def comm_group(self):
        # fake comm group to bypass tp logic
        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)

    def apply_impl(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the quant method to ``input_`` and return ``(output, output_bias)``.

        The tuple is returned unconditionally: the inherited ``apply`` unpacks
        two values and handles ``return_bias`` itself, so returning a bare
        tensor here would make that unpacking iterate over the tensor.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
def _get_column_parallel_op(
    prefix, layer
) -> MLPColumnParallelOp | SequenceColumnParallelOp | ShardedCPColumnParallelOp | Flashcomm2OshardQKVParallelOp | None:
    """Pick the custom column-parallel op for the layer named ``prefix``.

    The checks run in priority order and the first match wins; ``None`` means
    no custom op applies.
    """
    if enable_dsa_cp() and any(key in prefix for key in ("q_b_proj", "kv_b_proj")):
        return ShardedCPColumnParallelOp(layer)

    if "gate_up_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
        return MLPColumnParallelOp(layer)

    qkv_like = ("qkv_proj", "conv1d", "query_key_value")
    if flashcomm2_oshard_manager.flashcomm2_oshard_enable() and any(key in prefix for key in qkv_like):
        return Flashcomm2OshardQKVParallelOp(layer)

    # Sequence parallelism: shared-expert layers never get a custom op here.
    if not enable_sp() or "shared_expert" in prefix:
        return None

    sp_column_prefix = (
        "gate_up_proj",  # first MLP of most LLMs
        "in_proj",  # gated deltanet of Qwen3 Next
        "qkv_proj",  # qkv linear of most LLMs
        "conv1d",  # gated deltanet of Qwen3 Next
        "query_key_value",  # qkv linear of Bailing
    )
    if any(key in prefix for key in sp_column_prefix):
        return SequenceColumnParallelOp(layer)
    return None
def _get_row_parallel_op(
    prefix, layer
) -> (
    MLPRowParallelOp
    | OProjRowParallelOp
    | Flashcomm2OProjRowParallelOp
    | MatmulAllreduceRowParallelOp
    | SequenceRowParallelOp
    | ShardedCPRowParallelOp
    | None
):
    """Pick the custom row-parallel op for the layer named ``prefix``.

    The checks run in priority order and the first match wins; ``None`` means
    no custom op applies.
    """
    if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
        return ShardedCPRowParallelOp(layer)

    if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
        return MLPRowParallelOp(layer)

    if "o_proj" in prefix and oproj_tp_enable():
        return OProjRowParallelOp(layer)

    if matmul_allreduce_enable():
        return MatmulAllreduceRowParallelOp(layer)

    if flashcomm2_enable() and ("o_proj" in prefix or "out_proj" in prefix):
        return Flashcomm2OProjRowParallelOp(layer)

    # Sequence parallelism: shared-expert layers never get a custom op here.
    if not enable_sp() or "shared_expert" in prefix:
        return None

    sp_row_prefixes = (
        "o_proj",  # attn output linear of most LLMs
        "out_proj",  # attn output linear of Qwen3 Next
        "down_proj",  # second MLP of most LLMs
        "attention.dense",  # attn output linear of Bailing
    )
    if any(key in prefix for key in sp_row_prefixes):
        return SequenceRowParallelOp(layer)
    return None
def get_parallel_op(disable_tp, prefix, layer, direct):
    """Select the custom op for a parallel linear layer.

    Args:
        disable_tp: When truthy, no custom op is used.
        prefix: Layer name used to choose the op.
        layer: The linear layer the op will wrap.
        direct: ``"row"`` or ``"column"``; any other value selects no op.

    Returns:
        ``(custom_op, tp_rank, tp_size)``. With TP disabled, or for shared
        experts under shared-expert DP, this is ``(None, 0, 1)``. When no
        custom op matches, the rank and size of the default TP group are used.
    """
    # "shared_experts" always contains "shared_expert", so a single substring
    # test covers both spellings.
    if disable_tp or ("shared_expert" in prefix and shared_expert_dp_enabled()):
        return None, 0, 1

    custom_op = None
    if direct == "row":
        custom_op = _get_row_parallel_op(prefix, layer)
    elif direct == "column":
        custom_op = _get_column_parallel_op(prefix, layer)

    if custom_op is not None:
        return custom_op, custom_op.tp_rank, custom_op.tp_size

    tp_group = get_tp_group()
    return None, tp_group.rank_in_group, tp_group.world_size
def get_replicated_op(disable_tp, prefix, layer) -> tuple[CustomReplicatedOp | None, int | None, int | None]:
    """Build the custom op for a replicated linear layer.

    Returns ``(None, None, None)`` when TP is disabled, otherwise the op with
    its ``tp_rank`` and ``tp_size``. ``prefix`` is accepted but not used.
    """
    if not disable_tp:
        replicated_op = CustomReplicatedOp(layer)
        return replicated_op, replicated_op.tp_rank, replicated_op.tp_size
    return None, None, None
def is_moe_layer(prefix: str) -> bool:
    """Return whether the layer named by ``prefix`` is an MoE layer.

    The layer index is parsed from a ``layers.<idx>.`` component of the
    prefix; prefixes without one are never MoE layers. The decision uses
    ``n_routed_experts``, ``first_k_dense_replace`` and ``moe_layer_freq``
    from the current model's HF text config. A config lacking
    ``first_k_dense_replace`` yields ``False`` for every layer.
    """
    match = re.search(r"layers\.(\d+)\.", prefix)
    if match is None:
        return False
    layer_idx = int(match.group(1))

    # The config is read on every call. The previous ``lru_cache`` wrapped a
    # closure that was rebuilt per call, so it never produced a cache hit.
    # A process-wide cache is deliberately not introduced here because it
    # could serve stale values if the current vLLM config changes.
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config().model_config.hf_text_config
    n_routed_experts = getattr(config, "n_routed_experts", 0)
    first_k_dense_replace = getattr(config, "first_k_dense_replace", float("inf"))
    moe_layer_freq = getattr(config, "moe_layer_freq", 1)

    return n_routed_experts is not None and layer_idx >= first_k_dense_replace and layer_idx % moe_layer_freq == 0