[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce `layer_sharding` as a configurable feature in
`additional_config`
- Revise the term "shared weight" to "shard weight."
Configuration: The feature is opt-in via the `additional_config`
argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight replication
strategies and is treated as a separate, explicit feature. It can be used
in any scenario, combined with the flashcomm2
(https://github.com/vllm-project/vllm-ascend/pull/3232) feature or the
ShardedCP (#4702) feature, to achieve significant performance gains.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
```
@@ -25,11 +25,11 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
+from vllm_ascend.ops.layer_shard_linear import (
+    is_hidden_layer, post_process_after_loading_for_shard_weight_series,
+    reach_layer_for_shard_weight_series,
+    register_all_layers_to_shard_weight_series)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
-from vllm_ascend.ops.shared_weight_layer import (
-    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
-    reach_layer_for_shared_weight_series,
-    register_layer_to_shared_weight_series)
 from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -374,22 +374,17 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size

             # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
             self._replace_linear_class_for_sfa_cp()
-            from vllm_ascend.distributed.parallel_state import \
-                get_shared_weight_group
-            if is_hidden_layer(self.vllm_config, self.q_proj):
-                register_layer_to_shared_weight_series(
-                    series_name="q_proj",
-                    group=get_shared_weight_group(),
-                    layer=self.q_proj,
-                    prefetch_step=1)
-            if is_hidden_layer(self.vllm_config, self.o_proj):
-                register_layer_to_shared_weight_series(
-                    series_name="o_proj",
-                    group=get_shared_weight_group(),
-                    layer=self.o_proj,
-                    prefetch_step=1)
+            self.layer_sharding_kwargs = []
+            for layer_name in (get_ascend_config().layer_sharding or []):
+                if layer_name in kwargs:
+                    self.layer_sharding_kwargs.append(kwargs[layer_name])
+                else:
+                    logger.warning_once(
+                        f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
+                    )
+            register_all_layers_to_shard_weight_series(
+                self.layer_sharding_kwargs)

         # indexer param
         self.n_head: int = self.indexer.n_head  # 64
@@ -434,14 +429,10 @@ class AscendSFAImpl(MLAAttentionImpl):

         # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
         dispose_layer(self.kv_b_proj)

         if self.enable_sfa_cp:
-            if is_hidden_layer(self.vllm_config, self.q_proj):
-                post_process_after_loading_for_shared_weight_series(
-                    self.q_proj)
-            if is_hidden_layer(self.vllm_config, self.o_proj):
-                post_process_after_loading_for_shared_weight_series(
-                    self.o_proj)
+            for layer in (self.layer_sharding_kwargs or []):
+                if is_hidden_layer(layer):
+                    post_process_after_loading_for_shard_weight_series(layer)

         if self.enable_mlapo:
             quant_method = getattr(
@@ -751,10 +742,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         if attn_metadata is None:
             # Profiling run.
             if self.enable_sfa_cp and not forward_context.in_profile_run:
-                if is_hidden_layer(self.vllm_config, self.q_proj):
-                    reach_layer_for_shared_weight_series(self.q_proj)
-                if is_hidden_layer(self.vllm_config, self.o_proj):
-                    reach_layer_for_shared_weight_series(self.o_proj)
+                for layer in (self.layer_sharding_kwargs or []):
+                    if is_hidden_layer(layer):
+                        reach_layer_for_shard_weight_series(layer)
             return output.fill_(0)
         has_prefill = attn_metadata.has_prefill
         cos = attn_metadata.cos
@@ -809,10 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
                 slot_mapping_cp)

         if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
-            if is_hidden_layer(self.vllm_config, self.q_proj):
-                reach_layer_for_shared_weight_series(self.q_proj)
-            if is_hidden_layer(self.vllm_config, self.o_proj):
-                reach_layer_for_shared_weight_series(self.o_proj)
+            for layer in (self.layer_sharding_kwargs or []):
+                if is_hidden_layer(layer):
+                    reach_layer_for_shard_weight_series(layer)

         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
```