[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration : The feature is opt-in via the additional_config
argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight replication
strategies. It is treated as a separate, explicit feature.It can be used
in any scenario, combined with the
flashcomm2https://github.com/vllm-project/vllm-ascend/pull/3232 feature
or the ShardedCP #4702 feature, to achieve significant performance.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -31,15 +31,14 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.compilation.acl_graph import (
|
||||
get_draft_graph_params, get_graph_params,
|
||||
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.layer_shard_linear import (
|
||||
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
||||
reach_layer_for_shard_weight_series,
|
||||
register_all_layers_to_shard_weight_series)
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||
from vllm_ascend.ops.shared_weight_layer import (
|
||||
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
|
||||
reach_layer_for_shared_weight_series,
|
||||
register_layer_to_shared_weight_series)
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
|
||||
flashcomm2_o_shared_enabled, maybe_trans_nz,
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
|
||||
weak_ref_tensors)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
@@ -734,18 +733,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.kv_b_proj = kwargs['kv_b_proj']
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
from vllm_ascend.distributed.parallel_state import \
|
||||
get_shared_weight_group
|
||||
register_layer_to_shared_weight_series(
|
||||
series_name="o_proj",
|
||||
group=get_shared_weight_group(),
|
||||
layer=self.o_proj,
|
||||
prefetch_step=1)
|
||||
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
@@ -762,6 +749,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
self.layer_sharding_kwargs = []
|
||||
for layer_name in (get_ascend_config().layer_sharding or []):
|
||||
if layer_name in kwargs:
|
||||
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
||||
else:
|
||||
logger.warning_once(
|
||||
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
||||
)
|
||||
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
||||
@@ -833,9 +829,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# if mlapo, W_UK_T can't trans nz
|
||||
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
post_process_after_loading_for_shared_weight_series(self.o_proj)
|
||||
for layer in (self.layer_sharding_kwargs or []):
|
||||
if is_hidden_layer(layer):
|
||||
post_process_after_loading_for_shard_weight_series(layer)
|
||||
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
@@ -1445,9 +1441,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
for layer in (self.layer_sharding_kwargs or []):
|
||||
if is_hidden_layer(layer):
|
||||
reach_layer_for_shard_weight_series(layer)
|
||||
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
@@ -1478,9 +1474,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
for layer in (self.layer_sharding_kwargs or []):
|
||||
if is_hidden_layer(layer):
|
||||
reach_layer_for_shard_weight_series(layer)
|
||||
return output.fill_(0)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
|
||||
Reference in New Issue
Block a user