diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 2bffa44c..65affe8b 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -610,12 +610,16 @@ def _get_column_parallel_op( if enable_sp(): if "shared_expert" in prefix: return None - if "gate_up_proj" in prefix: - return SequenceColumnParallelOp(layer) - if "in_proj" in prefix: - return SequenceColumnParallelOp(layer) - if "qkv_proj" in prefix or "conv1d" in prefix: - return SequenceColumnParallelOp(layer) + sp_column_prefix = [ + "gate_up_proj", # first MLP of most LLMs + "in_proj", # gated deltanet of Qwen3 Next + "qkv_proj", # qkv linear of most LLMs + "conv1d", # gated deltanet of Qwen3 Next + "query_key_value", # qkv linear of Bailing + ] + for a_prefix in sp_column_prefix: + if a_prefix in prefix: + return SequenceColumnParallelOp(layer) return None @@ -637,8 +641,15 @@ def _get_row_parallel_op( if enable_sp(): if "shared_expert" in prefix: return None - if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix: - return SequenceRowParallelOp(layer) + sp_row_prefixes = [ + "o_proj", # attn output linear of most LLMs + "out_proj", # attn output linear of Qwen3 Next + "down_proj", # second MLP of most LLMs + "attention.dense", # attn output linear of Bailing + ] + for a_prefix in sp_row_prefixes: + if a_prefix in prefix: + return SequenceRowParallelOp(layer) return None