diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index f62f8b18..1472c6a3 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -49,6 +49,7 @@ The following table lists additional configuration options available in vLLM Asc
 | `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
 | `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
 | `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
+| `layer_sharding` | list | `None` | Names of the linear layers (e.g. `["o_proj"]`) whose weights are sharded layer by layer across the devices of a communication group. |
 
 The details of each configuration option are as follows:
diff --git a/docs/source/user_guide/feature_guide/images/layer_sharding.png b/docs/source/user_guide/feature_guide/images/layer_sharding.png
new file mode 100644
index 00000000..bc681be5
Binary files /dev/null and b/docs/source/user_guide/feature_guide/images/layer_sharding.png differ
diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md
index bc452aaa..96872abf 100644
--- a/docs/source/user_guide/feature_guide/index.md
+++ b/docs/source/user_guide/feature_guide/index.md
@@ -19,6 +19,7 @@ external_dp
 large_scale_ep
 ucm_deployment
 Fine_grained_TP
+layer_sharding
 speculative_decoding
 context_parallel
 :::
diff --git a/docs/source/user_guide/feature_guide/layer_sharding.md b/docs/source/user_guide/feature_guide/layer_sharding.md
new file mode 100644
index 00000000..3d7bc160
--- /dev/null
+++ b/docs/source/user_guide/feature_guide/layer_sharding.md
@@ -0,0 +1,73 @@
+---
+title: Layer Sharding Guide
+---
+
+# Overview
+
+**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.
+
+Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group**:
+- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
+- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on-demand via asynchronous broadcast** during the forward pass.
+
+As illustrated in the figure below, this design overlaps the weight broadcast with computation: while the current layer (e.g., MLA or MoE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module takes long enough, the weight transfer for `o_proj` is **fully overlapped with computation**, making the communication **latency-free from the perspective of end-to-end inference**.
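+
+The core idea can be sketched in a few lines of PyTorch. This is a minimal, illustrative example only; the helpers `owner_rank` and `prefetch_weight` are hypothetical names and not part of the vLLM Ascend API:
+
+```python
+import torch
+import torch.distributed as dist
+
+def owner_rank(layer_idx: int, group_size: int) -> int:
+    # Conceptual sketch: the i-th layer's weight is materialized only on
+    # device (i % K) of the communication group.
+    return layer_idx % group_size
+
+def prefetch_weight(layer_idx: int, weight_buffer: torch.Tensor,
+                    group: dist.ProcessGroup):
+    # Asynchronously broadcast layer `layer_idx`'s weight from its owner rank
+    # into a local buffer. The transfer runs in the background while earlier
+    # layers are still computing; call `.wait()` on the returned handle right
+    # before the weight is actually needed.
+    group_rank = owner_rank(layer_idx, dist.get_world_size(group))
+    src = dist.get_global_rank(group, group_rank)
+    return dist.broadcast(weight_buffer, src=src, group=group, async_op=True)
+```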
+ +This approach **preserves exact computational semantics** while **significantly reducing NPU memory footprint**, especially critical for: +- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers); +- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory per layer; +- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting. + +--- + +## Flowchart +![layer shard](./images/layer_sharding.png) + +> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes—enabling zero-overhead weight loading. + +--- + +# Getting Started + +To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use: + +```bash +--additional-config '{ + "layer_sharding": ["o_proj", "q_b_proj"] +}' +``` + +--- + +# Supported Scenarios + +This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases: + +## FlashComm2-enabled + +When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices. + +**Example configuration:** + +```bash +export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1 +vllm serve \ + --model DeepSeek-V3/R1 \ + --additional-config '{ + "layer_sharding": ["o_proj"] + }' +``` + +## DSA-CP-enabled + +With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory. 
+
+**Example configuration:**
+
+```bash
+export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
+vllm serve \
+  --model DeepSeek-V3.2 \
+  --additional-config '{
+      "layer_sharding": ["q_b_proj", "o_proj"]
+  }'
+```
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 1e7a9312..30914efa 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -25,14 +25,10 @@ def mock_distributed():
         patch('torch.distributed.get_world_size', return_value=16), \
         patch('torch.distributed.get_backend', return_value='nccl'), \
         patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
-        patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
-        patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
-        patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
+        patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
         mock_tp_group.return_value.world_size = 4
-        mock_dp_group.return_value.world_size = 2
-        mock_pp_group.return_value.world_size = 2
 
         yield
 
@@ -50,7 +46,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_vllm_config.kv_transfer_config.is_kv_producer = True
     mock_envs_ascend = MagicMock()
     mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
-    mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
     mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index ad53e687..81e50468 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -51,6 +51,12 @@ class AscendConfig:
             "weight_prefetch_config", {})
         self.weight_prefetch_config = WeightPrefetchConfig(
             weight_prefetch_config)
+        self.layer_sharding = additional_config.get("layer_sharding", None)
+        if self.layer_sharding:
+            logger.info_once(
+                f"Linear layer sharding enabled with config: {self.layer_sharding}. "
+                "Note: This feature works optimally with FLASHCOMM2 or DSA-CP enabled; "
+                "using it without these features may result in significant performance degradation.")
 
         # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm.
Remove this config self.expert_map_path = additional_config.get("expert_map_path", None) @@ -111,7 +117,7 @@ class AscendConfig: self.SLO_limits_for_dynamic_batch = additional_config.get( "SLO_limits_for_dynamic_batch", -1) from vllm_ascend.utils import get_flashcomm2_config_and_validate - self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate( + self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate( self, vllm_config) self.enable_npugraph_ex = additional_config.get( "enable_npugraph_ex", False) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 04b59dd5..23564888 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -31,15 +31,14 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.compilation.acl_graph import ( get_draft_graph_params, get_graph_params, update_draft_graph_params_workspaces, update_graph_params_workspaces) +from vllm_ascend.ops.layer_shard_linear import ( + is_hidden_layer, post_process_after_loading_for_shard_weight_series, + reach_layer_for_shard_weight_series, + register_all_layers_to_shard_weight_series) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.ops.shared_weight_layer import ( - is_hidden_layer, post_process_after_loading_for_shared_weight_series, - reach_layer_for_shared_weight_series, - register_layer_to_shared_weight_series) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, - flashcomm2_o_shared_enabled, maybe_trans_nz, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors) from vllm_ascend.worker.npu_input_batch import NPUInputBatch @@ -734,18 +733,6 @@ class AscendMLAImpl(MLAAttentionImpl): self.kv_b_proj = kwargs['kv_b_proj'] self.o_proj = kwargs['o_proj'] self.vllm_config = get_current_vllm_config() - self.fc2_o_shared_enable = flashcomm2_o_shared_enabled() - - if self.fc2_o_shared_enable and is_hidden_layer( - self.vllm_config, self.o_proj): - from vllm_ascend.distributed.parallel_state import \ - get_shared_weight_group - register_layer_to_shared_weight_series( - series_name="o_proj", - group=get_shared_weight_group(), - layer=self.o_proj, - prefetch_step=1) - self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) self.q_a_layernorm = kwargs.get('q_a_layernorm', None) @@ -762,6 +749,15 @@ class AscendMLAImpl(MLAAttentionImpl): self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + self.layer_sharding_kwargs = [] + for layer_name in (get_ascend_config().layer_sharding or []): + if layer_name in kwargs: + self.layer_sharding_kwargs.append(kwargs[layer_name]) + else: + logger.warning_once( + f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration" + ) + register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) def _v_up_proj(self, x): # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) @@ -833,9 +829,9 @@ class AscendMLAImpl(MLAAttentionImpl): # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) - if self.fc2_o_shared_enable and is_hidden_layer( - self.vllm_config, self.o_proj): - 
post_process_after_loading_for_shared_weight_series(self.o_proj) + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + post_process_after_loading_for_shard_weight_series(layer) def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[ @@ -1445,9 +1441,9 @@ class AscendMLAImpl(MLAAttentionImpl): kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( kv_no_split.contiguous(), need_gather_q_kv) - if self.fc2_o_shared_enable and is_hidden_layer( - self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) decode_preprocess_res = None prefill_preprocess_res = None @@ -1478,9 +1474,9 @@ class AscendMLAImpl(MLAAttentionImpl): assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. - if self.fc2_o_shared_enable and is_hidden_layer( - self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) return output.fill_(0) forward_context = get_forward_context() diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 7430575d..1bfe8e67 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -25,11 +25,11 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, trans_rope_weight, transdata, wait_for_kv_layer_from_connector) +from vllm_ascend.ops.layer_shard_linear import ( + is_hidden_layer, post_process_after_loading_for_shard_weight_series, + reach_layer_for_shard_weight_series, + register_all_layers_to_shard_weight_series) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.ops.shared_weight_layer import ( - is_hidden_layer, post_process_after_loading_for_shared_weight_series, - reach_layer_for_shared_weight_series, - register_layer_to_shared_weight_series) from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod @@ -374,22 +374,17 @@ class AscendSFAImpl(MLAAttentionImpl): if self.enable_sfa_cp: self.local_num_heads = self.num_heads * self.tp_size - # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. 
--clrs97 self._replace_linear_class_for_sfa_cp() - from vllm_ascend.distributed.parallel_state import \ - get_shared_weight_group - if is_hidden_layer(self.vllm_config, self.q_proj): - register_layer_to_shared_weight_series( - series_name="q_proj", - group=get_shared_weight_group(), - layer=self.q_proj, - prefetch_step=1) - if is_hidden_layer(self.vllm_config, self.o_proj): - register_layer_to_shared_weight_series( - series_name="o_proj", - group=get_shared_weight_group(), - layer=self.o_proj, - prefetch_step=1) + self.layer_sharding_kwargs = [] + for layer_name in (get_ascend_config().layer_sharding or []): + if layer_name in kwargs: + self.layer_sharding_kwargs.append(kwargs[layer_name]) + else: + logger.warning_once( + f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration" + ) + register_all_layers_to_shard_weight_series( + self.layer_sharding_kwargs) # indexer param self.n_head: int = self.indexer.n_head # 64 @@ -434,14 +429,10 @@ class AscendSFAImpl(MLAAttentionImpl): # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory dispose_layer(self.kv_b_proj) - if self.enable_sfa_cp: - if is_hidden_layer(self.vllm_config, self.q_proj): - post_process_after_loading_for_shared_weight_series( - self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - post_process_after_loading_for_shared_weight_series( - self.o_proj) + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + post_process_after_loading_for_shard_weight_series(layer) if self.enable_mlapo: quant_method = getattr( @@ -751,10 +742,9 @@ class AscendSFAImpl(MLAAttentionImpl): if attn_metadata is None: # Profiling run. if self.enable_sfa_cp and not forward_context.in_profile_run: - if is_hidden_layer(self.vllm_config, self.q_proj): - reach_layer_for_shared_weight_series(self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) return output.fill_(0) has_prefill = attn_metadata.has_prefill cos = attn_metadata.cos @@ -809,10 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl): slot_mapping_cp) if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None: - if is_hidden_layer(self.vllm_config, self.q_proj): - reach_layer_for_shared_weight_series(self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) q_pe = self.rope_single(q_pe, cos, sin) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 4d50cec0..9932867c 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -2,14 +2,12 @@ from typing import Optional import torch from vllm.config import ParallelConfig, get_current_vllm_config -from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, - get_pp_group, get_tp_group, +from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group, get_world_group, init_model_parallel_group) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (enable_sp, flashcomm2_enable, - flashcomm2_o_shared_enabled) +from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable # Currently, 
mc2 op need their own group coordinator. _MC2: Optional[GroupCoordinator] = None @@ -25,8 +23,8 @@ _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None _FC3_QUANT_X: Optional[GroupCoordinator] = None -# shared_weight across rank groups -_SHARED_WEIGHT: Optional[GroupCoordinator] = None +# shard_weight across rank groups +_SHARD_WEIGHT: Optional[GroupCoordinator] = None _P_TP: Optional[GroupCoordinator] = None @@ -37,7 +35,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() backend = torch.distributed.get_backend(get_world_group().device_group) - vllm_config = get_current_vllm_config() global_tp_size = parallel_config.tensor_parallel_size global_dp_size = parallel_config.data_parallel_size global_pp_size = parallel_config.pipeline_parallel_size @@ -48,6 +45,14 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): all_ranks = torch.arange(world_size).reshape( -1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size) + #TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future. + vllm_all_ranks = torch.arange(world_size).reshape( + -1, + global_dp_size, + global_pp_size, + parallel_config.prefill_context_parallel_size, + global_tp_size, + ) pd_tp_ratio = get_ascend_config().pd_tp_ratio pd_head_ratio = get_ascend_config().pd_head_ratio @@ -148,38 +153,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): if mlp_tp_size > 0: _MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp") - def _create_shared_weight_group(group_name: str) -> GroupCoordinator: - #This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference - group_ranks = [] - for pp_idx in range(global_pp_size): - group = [] - for dp_idx in range(global_dp_size): - base = (dp_idx * global_pp_size + pp_idx) * global_tp_size - for i in range(global_tp_size): - global_rank = base + i - group.append(global_rank) - group_ranks.append(group) - - return init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name=group_name) - - global _SHARED_WEIGHT - # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97 - is_ds_v32 = hasattr(vllm_config.model_config.hf_text_config, "index_topk") - if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None: - _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight") # TODO: Extract and unify the logic across different communication group. 
+ flashcomm2_otp_group_ranks = [] if flashcomm2_enable(): flashcomm2_otp_size = get_ascend_config( ).flashcomm2_oproj_tensor_parallel_size - global_tp_size = get_tp_group().world_size - global_dp_size = get_dp_group().world_size - global_pp_size = get_pp_group().world_size num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size) - global _FLASHCOMM2_OTP global _FLASHCOMM2_ODP @@ -187,7 +167,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): _FLASHCOMM2_ODP = get_tp_group() if flashcomm2_otp_size > 1: - otp_group_ranks = [] odp_group_ranks: list[list[int]] = [ [] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size) @@ -209,10 +188,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): odp_group_index = odp_base_index + j odp_group_ranks[odp_group_index].append( global_rank) - otp_group_ranks.append(ranks) + flashcomm2_otp_group_ranks.append(ranks) _FLASHCOMM2_OTP = init_model_parallel_group( - otp_group_ranks, + flashcomm2_otp_group_ranks, get_world_group().local_rank, backend, group_name="flashcomm2_otp") @@ -222,12 +201,50 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): backend, group_name="flashcomm2_odp") - # Create shared weight group for flashcomm2 oproj - if flashcomm2_o_shared_enabled(): - assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1" - if _SHARED_WEIGHT is None: - _SHARED_WEIGHT = _create_shared_weight_group( - "flashcomm2_o_shared") + def create_shard_weight_group( + module_tp_group_ranks: None) -> GroupCoordinator: + # Argument module_tp_group_ranks: The module specific tensor parallel group. + # There are three situations. + # 1. If it is None, then the TP_size of the specific module is 1 and is replicated linear layer. + # 2. If it is not None, and the module tp_group is same as the global tp_group. + # 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp) + group_ranks = [] + pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape( + -1, global_pp_size) + if module_tp_group_ranks is None: + # If it is None, then the TP_size of this shard weight is 1. + shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0) + group_ranks = [x.tolist() for x in shard_weight_group_ranks] + else: + # combine standard tp group and non-standard tp group to build shard_weight comm_group + module_tp_tanspose_ranks = module_tp_group_ranks.transpose(0, 1) + G = world_size // (global_pp_size * module_tp_group_ranks.size(1)) + shard_weight_group_ranks = torch.stack( + [t.view(global_pp_size, G) for t in module_tp_tanspose_ranks], + dim=1) + group_ranks = shard_weight_group_ranks.view(-1, G).tolist() + return init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="shard_weight") + + # Create shard weight group if enabled + if get_ascend_config().layer_sharding is not None: + global _SHARD_WEIGHT + if flashcomm2_enable(): + if len(flashcomm2_otp_group_ranks) == 0: + FC2_group_ranks = None + else: + FC2_group_ranks = torch.tensor( + flashcomm2_otp_group_ranks).squeeze(0) + _SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks) + elif enable_dsa_cp(): + # For dsa_cp, all shard layers are replicated. 
+ _SHARD_WEIGHT = create_shard_weight_group(None) + else: + # For standard tp, use global tp group_ranks + tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size) + _SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks) if get_ascend_config().multistream_overlap_gate: global _FC3_QUANT_X @@ -280,11 +297,10 @@ def get_flashcomm2_odp_group() -> GroupCoordinator: return _FLASHCOMM2_ODP -def get_shared_weight_group() -> GroupCoordinator: - assert _SHARED_WEIGHT is not None, ( - "output shared weight parallel group for flashcomm2 is not initialized" - ) - return _SHARED_WEIGHT +def get_shard_weight_group() -> GroupCoordinator: + assert _SHARD_WEIGHT is not None, ( + "output shard weight parallel group for flashcomm2 is not initialized") + return _SHARD_WEIGHT def get_p_tp_group() -> GroupCoordinator: @@ -341,10 +357,10 @@ def destroy_ascend_model_parallel(): _FLASHCOMM2_ODP.destroy() _FLASHCOMM2_ODP = None - global _SHARED_WEIGHT - if _SHARED_WEIGHT: - _SHARED_WEIGHT.destroy() - _SHARED_WEIGHT = None + global _SHARD_WEIGHT + if _SHARD_WEIGHT: + _SHARD_WEIGHT.destroy() + _SHARD_WEIGHT = None global _FC3_QUANT_X if _FC3_QUANT_X: diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py index 6b4b894e..70c57d28 100644 --- a/vllm_ascend/distributed/utils.py +++ b/vllm_ascend/distributed/utils.py @@ -2,10 +2,10 @@ import os import torch import torch.distributed as dist +from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context -from vllm_ascend.distributed.parallel_state import (get_dp_group, - get_fc3_quant_x_group, +from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group, get_p_tp_group) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index d4c8bf44..2c1fae14 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -92,11 +92,6 @@ env_variables: Dict[str, Callable[[], Any]] = { # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation. "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)), - # This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature, - # which can eliminate redundant storage of weights. More detailed information can be found in PR#4188. - # We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited. - "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED": - lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))), # Whether to enable MLP weight prefetch, only used in small concurrency. 
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), diff --git a/vllm_ascend/ops/shared_weight_layer.py b/vllm_ascend/ops/layer_shard_linear.py similarity index 76% rename from vllm_ascend/ops/shared_weight_layer.py rename to vllm_ascend/ops/layer_shard_linear.py index 1dc2e88d..e1224522 100644 --- a/vllm_ascend/ops/shared_weight_layer.py +++ b/vllm_ascend/ops/layer_shard_linear.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Callable, Optional +from functools import lru_cache +from typing import Callable, List, Optional import torch import torch.distributed as dist @@ -7,6 +8,8 @@ from vllm.distributed.parallel_state import GroupCoordinator from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.models.utils import extract_layer_index +from vllm_ascend.distributed.parallel_state import get_shard_weight_group + def dispose_tensor(x: torch.Tensor): x.set_(torch.empty([], device=x.device, dtype=x.dtype)) @@ -26,17 +29,17 @@ class LayerMetadata: @dataclass -class SharedWindowMetadata: - """Metadata for a shared window. +class ShardWindowMetadata: + """Metadata for a shard window. """ - weight: torch.Tensor # The weight tensor to be shared by layers. + weight: torch.Tensor # The weight tensor to be shard by layers. data_layer_idx: int # The index of the layer this window's weight is equal to. work: Optional[torch.distributed.Work] # The asynchronous broadcast work. @dataclass class SeriesMetadata: - """Metadata for a weight shared series. + """Metadata for a weight shard series. """ group: GroupCoordinator start_layer: int @@ -45,8 +48,8 @@ class SeriesMetadata: prefetch_step: int dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor. layers: list[LayerMetadata] - shared_windows: list[ - SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. + shard_windows: list[ + ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. window_offset: int # The index of the window for the next coming layer. def is_source(self, layer_idx) -> bool: @@ -54,7 +57,7 @@ class SeriesMetadata: def post_process_after_loading(self): # This method only needs to be called once per series. - if self.shared_windows: + if self.shard_windows: return self.layers.sort(key=lambda x: x.layer_idx) @@ -83,8 +86,8 @@ class SeriesMetadata: step = layer_idx - self.start_layer if step < self.prefetch_step: # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights. - self.shared_windows.append( - SharedWindowMetadata( + self.shard_windows.append( + ShardWindowMetadata( weight=layer.weight.clone().detach(), data_layer_idx=layer_idx, work=None, @@ -92,12 +95,12 @@ class SeriesMetadata: layer.window_idx = step # When the layer not intended to be stored in this device, link to the corresponding window's tensor. if not is_source: - layer.weight.set_(self.shared_windows[-1].weight) + layer.weight.set_(self.shard_windows[-1].weight) else: # Build one more window for prefetch. The weight is useless, so just keep the shape. 
if step == self.prefetch_step: - self.shared_windows.append( - SharedWindowMetadata( + self.shard_windows.append( + ShardWindowMetadata( weight=torch.empty_like(layer.weight), data_layer_idx=-1, work=None, @@ -115,7 +118,7 @@ class SeriesMetadata: next_layer = self.layers[next_layer_idx - self.start_layer] # The index of the window to store the weight for the coming layer. next_layer.window_idx = self.window_offset - window = self.shared_windows[next_layer.window_idx] + window = self.shard_windows[next_layer.window_idx] # When the layer not intended to be stored in this device, link to the corresponding window's tensor. if not self.is_source(next_layer_idx): next_layer.weight.set_(window.weight) @@ -133,10 +136,10 @@ class SeriesMetadata: def wait_weight(self, layer_idx: int): # Find the asynchronous broadcast work and wait for it. - assert self.shared_windows - window = self.shared_windows[self.layers[layer_idx - - self.start_layer].window_idx] - # Make sure the data in the corresponding shared window is for the current layer. + assert self.shard_windows + window = self.shard_windows[self.layers[layer_idx - + self.start_layer].window_idx] + # Make sure the data in the corresponding shard window is for the current layer. assert window.data_layer_idx == layer_idx if window.work is not None: window.work.wait() @@ -168,13 +171,13 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, """ -Register linear layers into a shared storage series. +Register linear layers into a shard storage series. In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices. -After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization. +After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization. -During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. +During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula: - start_layer is the index of the first layer in the series (inclusive). @@ -182,7 +185,7 @@ Note: The layers are managed as a circular buffer. 
The index of the layer to pre - total_layers = end_layer - start_layer - prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer -To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series. +To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series. Arguments: series_name: This name identifies which series this layer belongs to. @@ -192,7 +195,7 @@ Arguments: """ -def register_layer_to_shared_weight_series( +def register_layer_to_shard_weight_series( series_name: str, group: GroupCoordinator, layer: LinearBase, @@ -208,7 +211,7 @@ def register_layer_to_shared_weight_series( prefetch_step=prefetch_step, dummy_weight=torch.empty_like(layer.weight), layers=[], - shared_windows=[], + shard_windows=[], window_offset=prefetch_step, ) series = _series_dict[series_name] @@ -236,17 +239,42 @@ def register_layer_to_shared_weight_series( ) -def post_process_after_loading_for_shared_weight_series(layer: LinearBase): +def post_process_after_loading_for_shard_weight_series(layer: LinearBase): ext = _layer_external_dict[id(layer)] ext.series.post_process_after_loading() -def reach_layer_for_shared_weight_series(layer: LinearBase): +def reach_layer_for_shard_weight_series(layer: LinearBase): ext = _layer_external_dict[id(layer)] ext.series.reach_layer(ext.layer_idx) -def is_hidden_layer(vllm_config, layer: LinearBase) -> bool: - num_hidden_layers = vllm_config.model_config.hf_text_config.num_hidden_layers +def wait_layer_for_shard_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.wait_weight(ext.layer_idx) + + +@lru_cache(maxsize=1) +def get_current_model_num_hidden_layers() -> int: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + return vllm_config.model_config.get_total_num_hidden_layers() + + +def is_hidden_layer(layer: LinearBase) -> bool: + num_hidden_layers = get_current_model_num_hidden_layers() layer_idx = extract_layer_index(layer.prefix) return layer_idx < num_hidden_layers + + +def register_all_layers_to_shard_weight_series( + layer_sharding: List[LinearBase], ): + for curr_layer in (layer_sharding or []): + if is_hidden_layer(curr_layer): + layer_name = curr_layer.prefix.split('.')[-1] + register_layer_to_shard_weight_series( + series_name=layer_name, + group=get_shard_weight_group(), + layer=curr_layer, + prefetch_step=1, + ) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 80565554..09266716 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -23,6 +23,7 @@ import math import os from contextlib import contextmanager, nullcontext from enum import Enum +from functools import lru_cache from threading import Lock from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union @@ -1016,22 +1017,27 @@ def flashcomm2_enable() -> bool: return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0 -def flashcomm2_o_shared_enabled() -> bool: - return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED - - def get_flashcomm2_config_and_validate(ascend_config, vllm_config): flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE global_tp_size = vllm_config.parallel_config.tensor_parallel_size - flashcomm2_oproj_shared = flashcomm2_o_shared_enabled() if not flashcomm2_enable(): - flashcomm2_oproj_shared = False - return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared + return 0 
logger.info( - f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}" + f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}" ) + + layer_sharding = ascend_config.layer_sharding or [] + if layer_sharding: + if layer_sharding == ["o_proj"]: + logger.info_once( + "Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption." + ) + else: + raise ValueError( + "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! " + f"Found invalid layer_sharding: {layer_sharding}") if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1: logger.warning_once( "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance." @@ -1054,13 +1060,10 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config): ) if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer: raise AssertionError( - "FLASHCOMM2 primarily targets P-scenario deployments, " - "with additional support for hybrid deployment scenarios. " - "It is not applicable in D-scenario environments.") - if flashcomm2_oproj_shared: - logger.info("Enable FLASHCOMM2 with oproj_shared.") + "FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments." + ) - return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared + return flashcomm2_oproj_tp_size def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]: @@ -1160,4 +1163,20 @@ def singleton(cls): instances[cls] = cls(*args, **kwargs) return instances[cls] - return get_instance \ No newline at end of file + return get_instance + + +@lru_cache(maxsize=1) +def get_current_model_config(): + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + return vllm_config.model_config + + +#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1 +@lru_cache(maxsize=1) +def enable_dsa_cp() -> bool: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk") + return is_ds_v32 and enable_sp()