[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)

### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration : The feature is opt-in via the additional_config
argument:
```
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

This is orthogonal to standard tensor parallelism and weight replication
strategies. It is treated as a separate, explicit feature.It can be used
in any scenario, combined with the
flashcomm2https://github.com/vllm-project/vllm-ascend/pull/3232 feature
or the ShardedCP #4702 feature, to achieve significant performance.



- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
zzhxxx
2026-01-08 09:05:02 +08:00
committed by GitHub
parent 20a8cf061b
commit f7db812ed7
13 changed files with 288 additions and 169 deletions

View File

@@ -2,14 +2,12 @@ from typing import Optional
import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
get_pp_group, get_tp_group,
from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
get_world_group,
init_model_parallel_group)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
flashcomm2_o_shared_enabled)
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -25,8 +23,8 @@ _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None
# shared_weight across rank groups
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
# shard_weight across rank groups
_SHARD_WEIGHT: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
@@ -37,7 +35,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
vllm_config = get_current_vllm_config()
global_tp_size = parallel_config.tensor_parallel_size
global_dp_size = parallel_config.data_parallel_size
global_pp_size = parallel_config.pipeline_parallel_size
@@ -48,6 +45,14 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
all_ranks = torch.arange(world_size).reshape(
-1, global_dp_size * parallel_config.prefill_context_parallel_size *
global_tp_size)
#TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
vllm_all_ranks = torch.arange(world_size).reshape(
-1,
global_dp_size,
global_pp_size,
parallel_config.prefill_context_parallel_size,
global_tp_size,
)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio
@@ -148,38 +153,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if mlp_tp_size > 0:
_MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")
def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
#This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
group_ranks = []
for pp_idx in range(global_pp_size):
group = []
for dp_idx in range(global_dp_size):
base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
for i in range(global_tp_size):
global_rank = base + i
group.append(global_rank)
group_ranks.append(group)
return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=group_name)
global _SHARED_WEIGHT
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
# TODO: Extract and unify the logic across different communication group.
flashcomm2_otp_group_ranks = []
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size
global_tp_size = get_tp_group().world_size
global_dp_size = get_dp_group().world_size
global_pp_size = get_pp_group().world_size
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
global _FLASHCOMM2_OTP
global _FLASHCOMM2_ODP
@@ -187,7 +167,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
_FLASHCOMM2_ODP = get_tp_group()
if flashcomm2_otp_size > 1:
otp_group_ranks = []
odp_group_ranks: list[list[int]] = [
[] for _ in range(flashcomm2_otp_size * global_dp_size *
global_pp_size)
@@ -209,10 +188,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
odp_group_index = odp_base_index + j
odp_group_ranks[odp_group_index].append(
global_rank)
otp_group_ranks.append(ranks)
flashcomm2_otp_group_ranks.append(ranks)
_FLASHCOMM2_OTP = init_model_parallel_group(
otp_group_ranks,
flashcomm2_otp_group_ranks,
get_world_group().local_rank,
backend,
group_name="flashcomm2_otp")
@@ -222,12 +201,50 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="flashcomm2_odp")
# Create shared weight group for flashcomm2 oproj
if flashcomm2_o_shared_enabled():
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
if _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group(
"flashcomm2_o_shared")
def create_shard_weight_group(
module_tp_group_ranks: None) -> GroupCoordinator:
# Argument module_tp_group_ranks: The module specific tensor parallel group.
# There are three situations.
# 1. If it is None, then the TP_size of the specific module is 1 and is replicated linear layer.
# 2. If it is not None, and the module tp_group is same as the global tp_group.
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
group_ranks = []
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(
-1, global_pp_size)
if module_tp_group_ranks is None:
# If it is None, then the TP_size of this shard weight is 1.
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
group_ranks = [x.tolist() for x in shard_weight_group_ranks]
else:
# combine standard tp group and non-standard tp group to build shard_weight comm_group
module_tp_tanspose_ranks = module_tp_group_ranks.transpose(0, 1)
G = world_size // (global_pp_size * module_tp_group_ranks.size(1))
shard_weight_group_ranks = torch.stack(
[t.view(global_pp_size, G) for t in module_tp_tanspose_ranks],
dim=1)
group_ranks = shard_weight_group_ranks.view(-1, G).tolist()
return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="shard_weight")
# Create shard weight group if enabled
if get_ascend_config().layer_sharding is not None:
global _SHARD_WEIGHT
if flashcomm2_enable():
if len(flashcomm2_otp_group_ranks) == 0:
FC2_group_ranks = None
else:
FC2_group_ranks = torch.tensor(
flashcomm2_otp_group_ranks).squeeze(0)
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
elif enable_dsa_cp():
# For dsa_cp, all shard layers are replicated.
_SHARD_WEIGHT = create_shard_weight_group(None)
else:
# For standard tp, use global tp group_ranks
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
@@ -280,11 +297,10 @@ def get_flashcomm2_odp_group() -> GroupCoordinator:
return _FLASHCOMM2_ODP
def get_shared_weight_group() -> GroupCoordinator:
assert _SHARED_WEIGHT is not None, (
"output shared weight parallel group for flashcomm2 is not initialized"
)
return _SHARED_WEIGHT
def get_shard_weight_group() -> GroupCoordinator:
assert _SHARD_WEIGHT is not None, (
"output shard weight parallel group for flashcomm2 is not initialized")
return _SHARD_WEIGHT
def get_p_tp_group() -> GroupCoordinator:
@@ -341,10 +357,10 @@ def destroy_ascend_model_parallel():
_FLASHCOMM2_ODP.destroy()
_FLASHCOMM2_ODP = None
global _SHARED_WEIGHT
if _SHARED_WEIGHT:
_SHARED_WEIGHT.destroy()
_SHARED_WEIGHT = None
global _SHARD_WEIGHT
if _SHARD_WEIGHT:
_SHARD_WEIGHT.destroy()
_SHARD_WEIGHT = None
global _FC3_QUANT_X
if _FC3_QUANT_X: