[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)

### What this PR does / why we need it?
Based on the Sharded-CP feature:
- PR: https://github.com/vllm-project/vllm-ascend/pull/4702
- RFC: https://github.com/vllm-project/vllm/issues/30055

This PR officially integrates DSA-CP support for Deepseek V3.2 on top of
https://github.com/vllm-project/vllm-ascend/pull/4702, improving inference
efficiency and scalability under mixed prefill-decode workloads. The main
changes are:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op for TP=1 (a simplified sketch of the dispatch follows below).
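
The routing itself is small. The following is a hedged, self-contained sketch of how the dispatch is expected to behave when DSA-CP is enabled; `pick_custom_op`, `enable_dsa_cp`, and the dataclasses below are simplified stand-ins for the real `_get_column_parallel_op` / `_get_row_parallel_op` hooks shown in the diff, not the vllm-ascend API.

```python
# Minimal sketch (not the vLLM-Ascend API): illustrates the prefix-based routing
# added in this PR. All names here are simplified stand-ins.
from dataclasses import dataclass


@dataclass
class ShardedCPColumnParallelOp:
    layer: object  # wraps q_b_proj / kv_b_proj; runs the GEMM without TP sharding


@dataclass
class ShardedCPRowParallelOp:
    layer: object  # wraps o_proj; runs the GEMM and skips the TP all-reduce


def enable_dsa_cp() -> bool:
    # stand-in for vllm_ascend.utils.enable_dsa_cp()
    return True


def pick_custom_op(prefix: str, layer: object):
    """Route MLA projection layers to the sharded-CP custom ops."""
    if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
        return ShardedCPColumnParallelOp(layer)
    if enable_dsa_cp() and "o_proj" in prefix:
        return ShardedCPRowParallelOp(layer)
    return None  # fall through to the existing TP / flashcomm paths


if __name__ == "__main__":
    print(pick_custom_op("model.layers.0.self_attn.q_b_proj", layer=object()))
    print(pick_custom_op("model.layers.0.self_attn.o_proj", layer=object()))
```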

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Author: zzhxxx
Date: 2026-01-09 15:58:40 +08:00
Committed by: GitHub
Parent: e11ff8e535
Commit: 64d29875f9
4 changed files with 110 additions and 68 deletions


@@ -38,6 +38,7 @@ Row parallel op follows a similar approach - inherit from RowColumnParallelOp an
 import re
 from functools import lru_cache
+from types import SimpleNamespace
 from typing import Optional, Union
 import torch
@@ -59,7 +60,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                      get_flashcomm2_otp_group,
                                                      get_mlp_tp_group,
                                                      get_otp_group)
-from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable, shared_expert_dp_enabled)
@@ -609,9 +610,60 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         self.unique_prefix = self.layer.unique_prefix


+class ShardedCPRowParallelOp(CustomRowParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias_)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.layer.reduce_results = False
+
+
+class ShardedCPColumnParallelOp(CustomColumnParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 def _get_column_parallel_op(
-    prefix, layer
-) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
+    prefix, layer
+) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
+                    ShardedCPColumnParallelOp]]:
+    if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
+        return ShardedCPColumnParallelOp(layer)
     if "gate_up_proj" in prefix and mlp_tp_enable(
     ) and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
@@ -636,7 +688,9 @@ def _get_row_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
-                    SequenceRowParallelOp]]:
+                    SequenceRowParallelOp, ShardedCPRowParallelOp]]:
+    if enable_dsa_cp() and "o_proj" in prefix:
+        return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
@@ -670,7 +724,8 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
                               MLPRowParallelOp, OProjRowParallelOp,
                               Flashcomm2OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp]] = None
+                              SequenceRowParallelOp, ShardedCPRowParallelOp,
+                              ShardedCPColumnParallelOp]] = None
     if direct == "row":
         custom_op = _get_row_parallel_op(prefix, layer)
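
For context, the new ops bypass the tensor-parallel communication path by exposing a fake comm group with `world_size=1`. The sketch below is a simplified, self-contained illustration of that trick under stated assumptions; `FakeRowParallelBase` and `ShardedCPLikeOp` are hypothetical names, not classes from this repository.

```python
# Hedged sketch (assumption: a simplified base class standing in for
# CustomRowParallelOp): shows why a fake comm group with world_size=1
# makes generic tensor-parallel plumbing skip the all-reduce.
from types import SimpleNamespace

import torch


class FakeRowParallelBase:
    """Simplified stand-in for a row-parallel op wrapper."""

    @property
    def comm_group(self):
        raise NotImplementedError

    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        out = x @ weight.t()
        # Generic TP plumbing: only reduce when the group is actually sharded.
        if self.comm_group.world_size > 1:
            torch.distributed.all_reduce(out, group=self.comm_group.device_group)
        return out


class ShardedCPLikeOp(FakeRowParallelBase):
    @property
    def comm_group(self):
        # Same trick as in this PR: a world_size=1 namespace makes the base
        # class skip the all-reduce, so sharded-CP handles communication elsewhere.
        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)


if __name__ == "__main__":
    op = ShardedCPLikeOp()
    y = op.forward(torch.randn(2, 8), torch.randn(4, 8))
    print(y.shape)  # torch.Size([2, 4]), with no collective issued
```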