[Refactor] Replace the implementations of o_proj, q_b_proj, and kv_b_proj with custom_op for sharded CP (#5698)
### What this PR does / why we need it?
Based on the Sharded-CP feature:
- PR: https://github.com/vllm-project/vllm-ascend/pull/4702
- RFC: https://github.com/vllm-project/vllm/issues/30055
This PR integrates DeepSeek V3.2's DSA-CP support on the basis of
https://github.com/vllm-project/vllm-ascend/pull/4702, improving inference
efficiency and scalability under mixed prefill-decode workloads. The main
improvements include:
- Replace the implementations of o_proj, q_b_proj, and kv_b_proj with
custom_op so that these projections run as TP=1 under sharded CP (a
simplified sketch of the dispatch pattern follows below).
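
The sketch below is not the actual vllm-ascend code; `ShardedCPLinearOp` and `pick_op` are hypothetical stand-ins. It only illustrates the pattern used in this PR: the custom op exposes a fake comm group with `world_size=1`, so the surrounding tensor-parallel logic degenerates into a plain local matmul with no collectives, and the dispatch selects this op for o_proj / q_b_proj / kv_b_proj when DSA-CP is enabled.

```python
# Simplified, self-contained sketch of the pattern (assumed names, not the
# real vllm-ascend classes).
from types import SimpleNamespace

import torch


class ShardedCPLinearOp:
    """Stand-in for the ShardedCP*ParallelOp classes added in this PR."""

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor | None = None):
        self.weight = weight  # [out_features, in_features]
        self.bias = bias

    @property
    def comm_group(self):
        # Fake comm group: world_size=1 makes every TP branch a no-op.
        return SimpleNamespace(world_size=1, rank_in_group=0, device_group=None)

    def apply(self, input_: torch.Tensor) -> torch.Tensor:
        # No all-reduce / all-gather: the projection runs locally per CP rank.
        return torch.nn.functional.linear(input_, self.weight, self.bias)


def pick_op(prefix: str, dsa_cp_enabled: bool) -> str:
    # Mirrors the dispatch added to _get_column_parallel_op / _get_row_parallel_op.
    if dsa_cp_enabled and any(
            k in prefix for k in ("o_proj", "q_b_proj", "kv_b_proj")):
        return "ShardedCP op (fake comm group, no TP collectives)"
    return "default TP op"


if __name__ == "__main__":
    op = ShardedCPLinearOp(torch.randn(8, 16))
    print(op.comm_group.world_size)            # 1 -> TP logic is bypassed
    print(op.apply(torch.randn(4, 16)).shape)  # torch.Size([4, 8])
    print(pick_op("model.layers.0.self_attn.o_proj", dsa_cp_enabled=True))
```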
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
```diff
@@ -38,6 +38,7 @@ Row parallel op follows a similar approach - inherit from RowColumnParallelOp an
 import re
 from functools import lru_cache
+from types import SimpleNamespace
 from typing import Optional, Union
 
 import torch
@@ -59,7 +60,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                      get_flashcomm2_otp_group,
                                                      get_mlp_tp_group,
                                                      get_otp_group)
-from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable, shared_expert_dp_enabled)
@@ -609,9 +610,60 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         self.unique_prefix = self.layer.unique_prefix
 
 
+class ShardedCPRowParallelOp(CustomRowParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias_)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+    def update_attrs(self):
+        super().update_attrs()
+        self.layer.reduce_results = False
+
+
+class ShardedCPColumnParallelOp(CustomColumnParallelOp):
+
+    @property
+    def comm_group(self):
+        # fake comm group to bypass tp logic
+        return SimpleNamespace(world_size=1,
+                               rank_in_group=0,
+                               device_group=None)
+
+    def apply_impl(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 def _get_column_parallel_op(
     prefix, layer
-) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
+) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
+                    ShardedCPColumnParallelOp]]:
+    if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix):
+        return ShardedCPColumnParallelOp(layer)
     if "gate_up_proj" in prefix and mlp_tp_enable(
     ) and not is_moe_layer(prefix):
         return MLPColumnParallelOp(layer)
@@ -636,7 +688,9 @@ def _get_row_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
-                    SequenceRowParallelOp]]:
+                    SequenceRowParallelOp, ShardedCPRowParallelOp]]:
+    if enable_dsa_cp() and "o_proj" in prefix:
+        return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
@@ -670,7 +724,8 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
                              MLPRowParallelOp, OProjRowParallelOp,
                              Flashcomm2OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
-                             SequenceRowParallelOp]] = None
+                             SequenceRowParallelOp, ShardedCPRowParallelOp,
+                             ShardedCPColumnParallelOp]] = None
     if direct == "row":
         custom_op = _get_row_parallel_op(prefix, layer)
```