[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)

### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration: the feature is opt-in via the `additional_config`
argument:
```
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

This is orthogonal to standard tensor parallelism and weight replication
strategies. It is treated as a separate, explicit feature. It can be used
in any scenario, combined with the
flashcomm2 (https://github.com/vllm-project/vllm-ascend/pull/3232) feature
or the ShardedCP (#4702) feature, to achieve significant performance gains.



- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
zzhxxx
2026-01-08 09:05:02 +08:00
committed by GitHub
parent 20a8cf061b
commit f7db812ed7
13 changed files with 288 additions and 169 deletions

View File

@@ -23,6 +23,7 @@ import math
import os
from contextlib import contextmanager, nullcontext
from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -1016,22 +1017,27 @@ def flashcomm2_enable() -> bool:
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
def flashcomm2_o_shared_enabled() -> bool:
    """Report whether the FLASHCOMM2 o_proj shared-weight mode is switched on.

    Thin accessor over the ``VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED``
    environment flag exposed through ``envs_ascend``.
    """
    enabled: bool = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
    return enabled
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
if not flashcomm2_enable():
flashcomm2_oproj_shared = False
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
return 0
logger.info(
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
)
layer_sharding = ascend_config.layer_sharding or []
if layer_sharding:
if layer_sharding == ["o_proj"]:
logger.info_once(
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
)
else:
raise ValueError(
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
f"Found invalid layer_sharding: {layer_sharding}")
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
logger.warning_once(
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1054,13 +1060,10 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
)
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"FLASHCOMM2 primarily targets P-scenario deployments, "
"with additional support for hybrid deployment scenarios. "
"It is not applicable in D-scenario environments.")
if flashcomm2_oproj_shared:
logger.info("Enable FLASHCOMM2 with oproj_shared.")
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
)
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
return flashcomm2_oproj_tp_size
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
@@ -1160,4 +1163,20 @@ def singleton(cls):
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
return get_instance
@lru_cache(maxsize=1)
def get_current_model_config():
    """Return the ``model_config`` of the currently active vLLM config.

    Cached (maxsize=1) since the active config does not change after startup.
    """
    # Imported locally to avoid a module-level dependency on vllm.config here.
    from vllm.config import get_current_vllm_config

    return get_current_vllm_config().model_config
# TODO(zzhx1): enable_sp is temporarily reused to gate the dsa_cp feature of
# ds32; subsequent updates will introduce a dedicated interface.
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
    """Return True when dsa_cp should be active for the current model.

    Requires both that the model looks like DeepSeek-V3.2 (its HF config
    carries an ``index_topk`` attribute) and that SP is enabled.
    """
    from vllm.config import get_current_vllm_config

    hf_config = get_current_vllm_config().model_config.hf_config
    # Presence of `index_topk` is used as the DeepSeek-V3.2 marker.
    return hasattr(hf_config, "index_topk") and enable_sp()