[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable `VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`.
- Introduce `layer_sharding` as a configurable feature in `additional_config`.
- Revise the term "shared weight" to "shard weight".

Configuration: the feature is opt-in via the `--additional-config` argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight replication strategies and is treated as a separate, explicit feature. It can be used in any scenario, combined with the flashcomm2 feature (https://github.com/vllm-project/vllm-ascend/pull/3232) or the ShardedCP feature (#4702), to achieve significant performance gains.
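
For reference, a minimal usage sketch of enabling the option from the Python API rather than the CLI. It assumes the standard `vllm.LLM` entry point forwards `additional_config` to the Ascend plugin; the model name and parallel size below are placeholders, not part of this PR:
```python
# Hypothetical sketch: enable o_proj layer sharding via additional_config.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",            # placeholder model id
    tensor_parallel_size=8,           # layer sharding is orthogonal to TP
    additional_config={
        "layer_sharding": ["o_proj"]  # shard o_proj weights instead of replicating them
    },
)
```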
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
```diff
@@ -23,6 +23,7 @@ import math
 import os
 from contextlib import contextmanager, nullcontext
 from enum import Enum
+from functools import lru_cache
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
@@ -1016,22 +1017,27 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
 
 
-def flashcomm2_o_shared_enabled() -> bool:
-    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
-
-
 def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
-    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
 
     if not flashcomm2_enable():
-        flashcomm2_oproj_shared = False
-        return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
+        return 0
 
     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
     )
+
+    layer_sharding = ascend_config.layer_sharding or []
+    if layer_sharding:
+        if layer_sharding == ["o_proj"]:
+            logger.info_once(
+                "Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
+            )
+        else:
+            raise ValueError(
+                "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
+                f"Found invalid layer_sharding: {layer_sharding}")
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
             "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1054,13 +1060,10 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
         )
     if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
         raise AssertionError(
-            "FLASHCOMM2 primarily targets P-scenario deployments, "
-            "with additional support for hybrid deployment scenarios. "
-            "It is not applicable in D-scenario environments.")
-    if flashcomm2_oproj_shared:
-        logger.info("Enable FLASHCOMM2 with oproj_shared.")
+            "FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
+        )
 
-    return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
+    return flashcomm2_oproj_tp_size
 
 
 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
@@ -1160,4 +1163,20 @@ def singleton(cls):
             instances[cls] = cls(*args, **kwargs)
         return instances[cls]
 
     return get_instance
+
+
+@lru_cache(maxsize=1)
+def get_current_model_config():
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    return vllm_config.model_config
+
+
+#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
+@lru_cache(maxsize=1)
+def enable_dsa_cp() -> bool:
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    return is_ds_v32 and enable_sp()
```
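
The net interface change: `get_flashcomm2_config_and_validate` now returns only the o_proj tensor-parallel size instead of a `(tp_size, oproj_shared)` tuple, and sharding is driven by `layer_sharding` rather than the removed environment variable. A rough sketch of how a call site adapts (the caller shown is illustrative, not taken from this diff):
```python
# Illustrative call-site update, not part of this diff.
# Before: oproj_tp_size, oproj_shared = get_flashcomm2_config_and_validate(ascend_config, vllm_config)
oproj_tp_size = get_flashcomm2_config_and_validate(ascend_config, vllm_config)
oproj_sharded = "o_proj" in (ascend_config.layer_sharding or [])
```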