[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`.
- Introduce `layer_sharding` as a configurable feature in
`additional_config`.
- Rename the term "shared weight" to "shard weight".

Configuration: the feature is opt-in via the `additional_config`
argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
Layer sharding is orthogonal to standard tensor parallelism and weight
replication strategies, and is treated as a separate, explicit feature. It
can be used in any scenario, and can be combined with the flashcomm2
feature (https://github.com/vllm-project/vllm-ascend/pull/3232) or the
ShardedCP feature (#4702) to achieve significant performance gains.
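As a rough illustration of the opt-in behaviour described above, the sketch below parses the JSON string given to `--additional-config` and reads the `layer_sharding` list from it. The helper name `parse_layer_sharding` and its validation rules are assumptions made for this example; they are not taken from the vllm-ascend code base.

```python
import json
from typing import Any, Dict, List


def parse_layer_sharding(additional_config_json: str) -> List[str]:
    """Hypothetical helper (not part of vllm-ascend): return the layer
    names listed under "layer_sharding" in the --additional-config JSON.

    An absent or null key yields an empty list, which models the feature
    being opt-in (disabled unless explicitly configured).
    """
    config: Dict[str, Any] = json.loads(additional_config_json)
    layers = config.get("layer_sharding") or []
    # Assumed validation: the value must be a list of layer-name strings.
    if not isinstance(layers, list) or not all(
            isinstance(name, str) for name in layers):
        raise ValueError("layer_sharding must be a list of layer-name strings")
    return layers


if __name__ == "__main__":
    cli_value = '{"layer_sharding": ["o_proj", "q_b_proj"]}'
    print(parse_layer_sharding(cli_value))
```

With an empty config (`'{}'`) the helper returns an empty list, so nothing is sharded unless the user asks for it.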
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
@@ -92,11 +92,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
     "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
-    # This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
-    # which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
-    # We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
-    "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
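The hunk above edits a registry that maps each environment-variable name to a zero-argument lambda, so a value is read from the environment only when the lambda is called. Below is a minimal, self-contained sketch of that pattern as it looks after this change. The accessor `read_env` is hypothetical and added only for illustration, and the string default `"0"` for the parallel size is a simplification of the source's `0`.

```python
import os
from typing import Any, Callable, Dict

# Reduced copy of the registry pattern from the diff: each entry is a
# zero-argument callable evaluated lazily at lookup time.
env_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
    lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", "0")),
    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", "0"))),
}


def read_env(name: str) -> Any:
    """Hypothetical accessor: evaluate the registered lambda for `name`."""
    if name not in env_variables:
        raise KeyError(f"{name} is not a registered environment variable")
    return env_variables[name]()


if __name__ == "__main__":
    os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"
    print(read_env("VLLM_ASCEND_ENABLE_PREFETCH_MLP"))
```

In this sketch, looking up `VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED` raises `KeyError`, since the entry is no longer registered. How the real module reacts to a removed name is not shown in the diff.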