[feat]: oproj tensor parallelism in pure DP and graph-mode scenarios. (#2167)
### What this PR does / why we need it?
This PR introduces Oproj matrix tensor model parallel to achieve
decreasing of memory consumption. It only support graph mode in pure DP
scenario.
In deepseek r1 w8a8 PD disagregated Decode instance, using pure DP, with
oproj_tensor_parallel_size = 8, we have 1 ms TPOT increasing, saved 5.8
GB NPU memory per RANK. We got best performance when
oproj_tensor_parallel_size=4 without TPOT increasing.
performance data:
<img width="1442" height="442" alt="image"
src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d"
/>
### Does this PR introduce _any_ user-facing change?
This PR introduces one new config in `additional_config`.
| Name | Effect | Required | Type | Constraints |
| :---------------------------- |
:--------------------------------------- | :------- | :--- |
:----------------- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row
dimension (head num * head dim) into oproj_tensor_parallel_size pieces.
| No | int | default value is None, once this value is set, the feature
will be enabled, head num * head dim must be divisible by this value. |
example
`--additional_config={"oproj_tensor_parallel_size": 8}`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
eddaafc1c7
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zzh <zzh_201018@outlook.com>
This commit is contained in:
@@ -491,9 +491,9 @@ def register_ascend_customop():
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
|
||||
AscendMlpMergedColumnParallelLinear,
|
||||
AscendMlpRowParallelLinear)
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||
@@ -504,6 +504,12 @@ def register_ascend_customop():
|
||||
name="SiluAndMul")
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding,
|
||||
name="RotaryEmbedding")
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendColumnParallelLinear,
|
||||
name="ColumnParallelLinear")
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendRowParallelLinear,
|
||||
name="RowParallelLinear")
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
|
||||
name="MergedColumnParallelLinear")
|
||||
CustomOp.register_oot(
|
||||
_decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
|
||||
name="DeepseekScalingRotaryEmbedding")
|
||||
@@ -513,14 +519,6 @@ def register_ascend_customop():
|
||||
name="ParallelLMHead")
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor,
|
||||
name="LogitsProcessor")
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE:
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendMlpColumnParallelLinear,
|
||||
name="ColumnParallelLinear")
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendMlpRowParallelLinear,
|
||||
name="RowParallelLinear")
|
||||
CustomOp.register_oot(
|
||||
_decorated_op_cls=AscendMlpMergedColumnParallelLinear,
|
||||
name="MergedColumnParallelLinear")
|
||||
|
||||
from vllm_ascend.ops.layernorm import AscendRMSNorm
|
||||
CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
|
||||
@@ -562,3 +560,15 @@ def get_ascend_soc_version():
|
||||
|
||||
def lmhead_tp_enable() -> bool:
|
||||
return get_ascend_config().lmhead_tensor_parallel_size is not None
|
||||
|
||||
|
||||
def oproj_tp_enable() -> bool:
|
||||
return get_ascend_config().oproj_tensor_parallel_size is not None
|
||||
|
||||
|
||||
def mlp_tp_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE
|
||||
|
||||
|
||||
def matmul_allreduce_enable() -> bool:
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
|
||||
Reference in New Issue
Block a user