[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration: the feature is opt-in via the `additional_config`
argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight replication
strategies and is treated as a separate, explicit feature. It can be used
in any scenario, and can be combined with the flashcomm2
(https://github.com/vllm-project/vllm-ascend/pull/3232) feature
or the ShardedCP (#4702) feature to achieve significant performance gains.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
@@ -49,6 +49,7 @@ The following table lists additional configuration options available in vLLM Asc
|
|||||||
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
|
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
|
||||||
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
|
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
|
||||||
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
|
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
|
||||||
|
| `layer_sharding` | list | `None` | Names of the linear layers to shard across devices, e.g. `["o_proj", "q_b_proj"]`. |
|
||||||
|
|
||||||
The details of each configuration option are as follows:
|
The details of each configuration option are as follows:
|
||||||
|
|
||||||
|
|||||||
BIN
docs/source/user_guide/feature_guide/images/layer_sharding.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 156 KiB |
@@ -19,6 +19,7 @@ external_dp
|
|||||||
large_scale_ep
|
large_scale_ep
|
||||||
ucm_deployment
|
ucm_deployment
|
||||||
Fine_grained_TP
|
Fine_grained_TP
|
||||||
|
layer_sharding
|
||||||
speculative_decoding
|
speculative_decoding
|
||||||
context_parallel
|
context_parallel
|
||||||
:::
|
:::
|
||||||
|
|||||||
73
docs/source/user_guide/feature_guide/layer_sharding.md
Normal file
@@ -0,0 +1,73 @@
---
title: Layer Sharding Guide
---

# Overview

**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.

Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group**:

- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on-demand via asynchronous broadcast** during the forward pass.
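
The placement rule in the two bullets above can be sketched in a few lines of Python. This is only an illustration: `owner_rank` and `layers_owned_by` are hypothetical helper names, ranks are plain integers within the group, and none of this is the code in `vllm_ascend.ops.layer_shard_linear`.

```python
def owner_rank(layer_idx: int, group_size: int) -> int:
    """Rank within the group that stores the real weight of layer `layer_idx`."""
    if group_size <= 0:
        raise ValueError("group_size must be positive")
    return layer_idx % group_size


def layers_owned_by(rank: int, num_layers: int, group_size: int) -> list:
    """All layer indices whose real weight lives on `rank` (hypothetical helper)."""
    return [i for i in range(num_layers) if owner_rank(i, group_size) == rank]


if __name__ == "__main__":
    # Illustrative numbers only: 8 layers sharded across a group of 4 devices.
    for rank in range(4):
        print(rank, layers_owned_by(rank, num_layers=8, group_size=4))
```

Every other rank would hold only the dummy tensor for the layers it does not own and receive the real weight by broadcast from the owner.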

As illustrated in the figure below, devices that do not store a layer's weight receive it by broadcast: while the current layer (e.g., MLA or MoE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module takes long enough to cover the transfer, the weight broadcast for `o_proj` is **fully overlapped with computation**, so the communication adds **no visible latency to end-to-end inference**.
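
The overlap described above can be pictured as a simple schedule that pairs the layer being computed with the layer whose weight is being broadcast at the same time. The sketch below is an assumption-laden illustration: `prefetch_schedule` is a hypothetical name, a prefetch distance of one layer is assumed, and the last layers simply prefetch nothing (the real implementation may behave differently at the end of a forward pass).

```python
from typing import List, Optional, Tuple


def prefetch_schedule(num_layers: int,
                      prefetch_step: int = 1) -> List[Tuple[int, Optional[int]]]:
    """Pair each computing layer with the layer whose weight is broadcast meanwhile.

    Assumption: no wrap-around, so the final `prefetch_step` layers prefetch nothing.
    """
    schedule = []
    for current in range(num_layers):
        upcoming = current + prefetch_step
        schedule.append((current, upcoming if upcoming < num_layers else None))
    return schedule


if __name__ == "__main__":
    for computing, broadcasting in prefetch_schedule(4):
        print(f"compute layer {computing} | broadcast weight of layer {broadcasting}")
```

The communication stays hidden only if each broadcast finishes before the layer that needs the weight starts its linear operator.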

This approach **preserves exact computational semantics** while **significantly reducing NPU memory footprint**, which is especially important for:

- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers);
- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory per layer;
- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting.
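
A rough way to reason about the memory saving is to count how many real weights of one series stay resident on a single device. The sketch below is a back-of-the-envelope estimate under stated assumptions: every layer of the model takes part in the series, the dummy tensor and any prefetch buffers are ignored, and `bytes_per_layer` is an illustrative figure rather than a measured one. The function names are hypothetical.

```python
import math


def resident_layers(num_layers: int, group_size: int) -> int:
    """Upper bound on the real weights of one series stored on a single device."""
    return math.ceil(num_layers / group_size)


def weight_bytes_saved(num_layers: int, group_size: int, bytes_per_layer: int) -> int:
    """Bytes no longer held per device compared with storing every layer's weight.

    Assumption: ignores the shared dummy tensor and prefetch buffers.
    """
    return (num_layers - resident_layers(num_layers, group_size)) * bytes_per_layer


if __name__ == "__main__":
    # 61 layers as in the DeepSeek-V3/R1 example; group size and size per
    # layer are made-up numbers for illustration.
    print(resident_layers(61, 16))
    print(weight_bytes_saved(61, 16, bytes_per_layer=100 * 2**20))
```

With these illustrative numbers each device keeps at most 4 of the 61 weights in the series instead of all of them.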

---

## Flowchart

![layer shard](./images/layer_sharding.png)

> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes, enabling zero-overhead weight loading.

---

# Getting Started

To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:

```bash
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```
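
When the launch command is assembled by a script, the JSON value can be generated rather than typed by hand, which avoids quoting mistakes. The following is a minimal sketch using only the Python standard library; `additional_config_arg` is a hypothetical helper, and only the `layer_sharding` key and the `--additional-config` flag come from this guide.

```python
import json
import shlex


def additional_config_arg(layer_names: list) -> str:
    """Build a shell-safe `--additional-config` argument for the given layer names."""
    payload = json.dumps({"layer_sharding": layer_names})
    # shlex.quote wraps the JSON in single quotes so the shell passes it through intact.
    return f"--additional-config {shlex.quote(payload)}"


if __name__ == "__main__":
    print(additional_config_arg(["o_proj", "q_b_proj"]))
```

The resulting string can be appended to the rest of the launch command, and parsing it back with `shlex.split` and `json.loads` recovers the original list.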

---

# Supported Scenarios

This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases:

## FlashComm2-enabled

When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.

**Example configuration:**

```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve \
  --model DeepSeek-V3/R1 \
  --additional-config '{
    "layer_sharding": ["o_proj"]
  }'
```

## DSA-CP-enabled

With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.

**Example configuration:**

```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve \
  --model DeepSeek-V3.2 \
  --additional-config '{
    "layer_sharding": ["q_b_proj", "o_proj"]
  }'
```
@@ -25,14 +25,10 @@ def mock_distributed():
|
|||||||
patch('torch.distributed.get_world_size', return_value=16), \
|
patch('torch.distributed.get_world_size', return_value=16), \
|
||||||
patch('torch.distributed.get_backend', return_value='nccl'), \
|
patch('torch.distributed.get_backend', return_value='nccl'), \
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
|
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
|
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
|
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
|
|
||||||
mock_group.return_value.local_rank = 0
|
mock_group.return_value.local_rank = 0
|
||||||
mock_group.return_value.device_group = MagicMock()
|
mock_group.return_value.device_group = MagicMock()
|
||||||
mock_tp_group.return_value.world_size = 4
|
mock_tp_group.return_value.world_size = 4
|
||||||
mock_dp_group.return_value.world_size = 2
|
|
||||||
mock_pp_group.return_value.world_size = 2
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@@ -50,7 +46,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
|
|||||||
mock_vllm_config.kv_transfer_config.is_kv_producer = True
|
mock_vllm_config.kv_transfer_config.is_kv_producer = True
|
||||||
mock_envs_ascend = MagicMock()
|
mock_envs_ascend = MagicMock()
|
||||||
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
|
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
|
||||||
mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
|
|
||||||
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
|
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
|
||||||
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
|
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
|
||||||
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
|
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
|
||||||
|
|||||||
@@ -51,6 +51,12 @@ class AscendConfig:
|
|||||||
"weight_prefetch_config", {})
|
"weight_prefetch_config", {})
|
||||||
self.weight_prefetch_config = WeightPrefetchConfig(
|
self.weight_prefetch_config = WeightPrefetchConfig(
|
||||||
weight_prefetch_config)
|
weight_prefetch_config)
|
||||||
|
self.layer_sharding = additional_config.get("layer_sharding", None)
|
||||||
|
logger.info_once(
|
||||||
|
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
|
||||||
|
"Note: This feature works optimally with FLASHCOMM2 and DSA-CP enabled; "
|
||||||
|
"using it without these features may result in significant performance degradation."
|
||||||
|
)
|
||||||
|
|
||||||
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
|
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
|
||||||
self.expert_map_path = additional_config.get("expert_map_path", None)
|
self.expert_map_path = additional_config.get("expert_map_path", None)
|
||||||
@@ -111,7 +117,7 @@ class AscendConfig:
|
|||||||
self.SLO_limits_for_dynamic_batch = additional_config.get(
|
self.SLO_limits_for_dynamic_batch = additional_config.get(
|
||||||
"SLO_limits_for_dynamic_batch", -1)
|
"SLO_limits_for_dynamic_batch", -1)
|
||||||
from vllm_ascend.utils import get_flashcomm2_config_and_validate
|
from vllm_ascend.utils import get_flashcomm2_config_and_validate
|
||||||
self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
|
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
|
||||||
self, vllm_config)
|
self, vllm_config)
|
||||||
self.enable_npugraph_ex = additional_config.get(
|
self.enable_npugraph_ex = additional_config.get(
|
||||||
"enable_npugraph_ex", False)
|
"enable_npugraph_ex", False)
|
||||||
|
|||||||
@@ -31,15 +31,14 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
from vllm_ascend.compilation.acl_graph import (
|
from vllm_ascend.compilation.acl_graph import (
|
||||||
get_draft_graph_params, get_graph_params,
|
get_draft_graph_params, get_graph_params,
|
||||||
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
||||||
|
from vllm_ascend.ops.layer_shard_linear import (
|
||||||
|
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
||||||
|
reach_layer_for_shard_weight_series,
|
||||||
|
register_all_layers_to_shard_weight_series)
|
||||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||||
from vllm_ascend.ops.shared_weight_layer import (
|
|
||||||
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
|
|
||||||
reach_layer_for_shared_weight_series,
|
|
||||||
register_layer_to_shared_weight_series)
|
|
||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
|
||||||
flashcomm2_o_shared_enabled, maybe_trans_nz,
|
|
||||||
weak_ref_tensors)
|
weak_ref_tensors)
|
||||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||||
|
|
||||||
@@ -734,18 +733,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.kv_b_proj = kwargs['kv_b_proj']
|
self.kv_b_proj = kwargs['kv_b_proj']
|
||||||
self.o_proj = kwargs['o_proj']
|
self.o_proj = kwargs['o_proj']
|
||||||
self.vllm_config = get_current_vllm_config()
|
self.vllm_config = get_current_vllm_config()
|
||||||
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
|
|
||||||
|
|
||||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
|
||||||
self.vllm_config, self.o_proj):
|
|
||||||
from vllm_ascend.distributed.parallel_state import \
|
|
||||||
get_shared_weight_group
|
|
||||||
register_layer_to_shared_weight_series(
|
|
||||||
series_name="o_proj",
|
|
||||||
group=get_shared_weight_group(),
|
|
||||||
layer=self.o_proj,
|
|
||||||
prefetch_step=1)
|
|
||||||
|
|
||||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||||
@@ -762,6 +749,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||||
|
|
||||||
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||||
|
self.layer_sharding_kwargs = []
|
||||||
|
for layer_name in (get_ascend_config().layer_sharding or []):
|
||||||
|
if layer_name in kwargs:
|
||||||
|
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
||||||
|
else:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
||||||
|
)
|
||||||
|
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
||||||
|
|
||||||
def _v_up_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
||||||
@@ -833,9 +829,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
# if mlapo, W_UK_T can't trans nz
|
# if mlapo, W_UK_T can't trans nz
|
||||||
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
||||||
|
|
||||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
self.vllm_config, self.o_proj):
|
if is_hidden_layer(layer):
|
||||||
post_process_after_loading_for_shared_weight_series(self.o_proj)
|
post_process_after_loading_for_shard_weight_series(layer)
|
||||||
|
|
||||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||||
@@ -1445,9 +1441,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
kv_no_split.contiguous(), need_gather_q_kv)
|
kv_no_split.contiguous(), need_gather_q_kv)
|
||||||
|
|
||||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
self.vllm_config, self.o_proj):
|
if is_hidden_layer(layer):
|
||||||
reach_layer_for_shared_weight_series(self.o_proj)
|
reach_layer_for_shard_weight_series(layer)
|
||||||
|
|
||||||
decode_preprocess_res = None
|
decode_preprocess_res = None
|
||||||
prefill_preprocess_res = None
|
prefill_preprocess_res = None
|
||||||
@@ -1478,9 +1474,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
self.vllm_config, self.o_proj):
|
if is_hidden_layer(layer):
|
||||||
reach_layer_for_shared_weight_series(self.o_proj)
|
reach_layer_for_shard_weight_series(layer)
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
|
|||||||
@@ -25,11 +25,11 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
|||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
trans_rope_weight, transdata,
|
trans_rope_weight, transdata,
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
|
from vllm_ascend.ops.layer_shard_linear import (
|
||||||
|
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
||||||
|
reach_layer_for_shard_weight_series,
|
||||||
|
register_all_layers_to_shard_weight_series)
|
||||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||||
from vllm_ascend.ops.shared_weight_layer import (
|
|
||||||
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
|
|
||||||
reach_layer_for_shared_weight_series,
|
|
||||||
register_layer_to_shared_weight_series)
|
|
||||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
@@ -374,22 +374,17 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
self.local_num_heads = self.num_heads * self.tp_size
|
self.local_num_heads = self.num_heads * self.tp_size
|
||||||
|
|
||||||
# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
|
|
||||||
self._replace_linear_class_for_sfa_cp()
|
self._replace_linear_class_for_sfa_cp()
|
||||||
from vllm_ascend.distributed.parallel_state import \
|
self.layer_sharding_kwargs = []
|
||||||
get_shared_weight_group
|
for layer_name in (get_ascend_config().layer_sharding or []):
|
||||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
if layer_name in kwargs:
|
||||||
register_layer_to_shared_weight_series(
|
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
||||||
series_name="q_proj",
|
else:
|
||||||
group=get_shared_weight_group(),
|
logger.warning_once(
|
||||||
layer=self.q_proj,
|
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
||||||
prefetch_step=1)
|
)
|
||||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
register_all_layers_to_shard_weight_series(
|
||||||
register_layer_to_shared_weight_series(
|
self.layer_sharding_kwargs)
|
||||||
series_name="o_proj",
|
|
||||||
group=get_shared_weight_group(),
|
|
||||||
layer=self.o_proj,
|
|
||||||
prefetch_step=1)
|
|
||||||
|
|
||||||
# indexer param
|
# indexer param
|
||||||
self.n_head: int = self.indexer.n_head # 64
|
self.n_head: int = self.indexer.n_head # 64
|
||||||
@@ -434,14 +429,10 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
|
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
|
||||||
dispose_layer(self.kv_b_proj)
|
dispose_layer(self.kv_b_proj)
|
||||||
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
post_process_after_loading_for_shared_weight_series(
|
if is_hidden_layer(layer):
|
||||||
self.q_proj)
|
post_process_after_loading_for_shard_weight_series(layer)
|
||||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
|
||||||
post_process_after_loading_for_shared_weight_series(
|
|
||||||
self.o_proj)
|
|
||||||
|
|
||||||
if self.enable_mlapo:
|
if self.enable_mlapo:
|
||||||
quant_method = getattr(
|
quant_method = getattr(
|
||||||
@@ -751,10 +742,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
if self.enable_sfa_cp and not forward_context.in_profile_run:
|
if self.enable_sfa_cp and not forward_context.in_profile_run:
|
||||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
reach_layer_for_shared_weight_series(self.q_proj)
|
if is_hidden_layer(layer):
|
||||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
reach_layer_for_shard_weight_series(layer)
|
||||||
reach_layer_for_shared_weight_series(self.o_proj)
|
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
has_prefill = attn_metadata.has_prefill
|
has_prefill = attn_metadata.has_prefill
|
||||||
cos = attn_metadata.cos
|
cos = attn_metadata.cos
|
||||||
@@ -809,10 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
slot_mapping_cp)
|
slot_mapping_cp)
|
||||||
|
|
||||||
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
|
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
|
||||||
if is_hidden_layer(self.vllm_config, self.q_proj):
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
reach_layer_for_shared_weight_series(self.q_proj)
|
if is_hidden_layer(layer):
|
||||||
if is_hidden_layer(self.vllm_config, self.o_proj):
|
reach_layer_for_shard_weight_series(layer)
|
||||||
reach_layer_for_shared_weight_series(self.o_proj)
|
|
||||||
|
|
||||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||||
q_pe = self.rope_single(q_pe, cos, sin)
|
q_pe = self.rope_single(q_pe, cos, sin)
|
||||||
|
|||||||
@@ -2,14 +2,12 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
|
from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
|
||||||
get_pp_group, get_tp_group,
|
|
||||||
get_world_group,
|
get_world_group,
|
||||||
init_model_parallel_group)
|
init_model_parallel_group)
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
|
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable
|
||||||
flashcomm2_o_shared_enabled)
|
|
||||||
|
|
||||||
# Currently, mc2 op need their own group coordinator.
|
# Currently, mc2 op need their own group coordinator.
|
||||||
_MC2: Optional[GroupCoordinator] = None
|
_MC2: Optional[GroupCoordinator] = None
|
||||||
@@ -25,8 +23,8 @@ _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
|
|||||||
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
|
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
|
||||||
_FC3_QUANT_X: Optional[GroupCoordinator] = None
|
_FC3_QUANT_X: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
# shared_weight across rank groups
|
# shard_weight across rank groups
|
||||||
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
|
_SHARD_WEIGHT: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
_P_TP: Optional[GroupCoordinator] = None
|
_P_TP: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
@@ -37,7 +35,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
assert torch.distributed.is_initialized()
|
assert torch.distributed.is_initialized()
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
backend = torch.distributed.get_backend(get_world_group().device_group)
|
backend = torch.distributed.get_backend(get_world_group().device_group)
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
global_tp_size = parallel_config.tensor_parallel_size
|
global_tp_size = parallel_config.tensor_parallel_size
|
||||||
global_dp_size = parallel_config.data_parallel_size
|
global_dp_size = parallel_config.data_parallel_size
|
||||||
global_pp_size = parallel_config.pipeline_parallel_size
|
global_pp_size = parallel_config.pipeline_parallel_size
|
||||||
@@ -48,6 +45,14 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
all_ranks = torch.arange(world_size).reshape(
|
all_ranks = torch.arange(world_size).reshape(
|
||||||
-1, global_dp_size * parallel_config.prefill_context_parallel_size *
|
-1, global_dp_size * parallel_config.prefill_context_parallel_size *
|
||||||
global_tp_size)
|
global_tp_size)
|
||||||
|
#TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
|
||||||
|
vllm_all_ranks = torch.arange(world_size).reshape(
|
||||||
|
-1,
|
||||||
|
global_dp_size,
|
||||||
|
global_pp_size,
|
||||||
|
parallel_config.prefill_context_parallel_size,
|
||||||
|
global_tp_size,
|
||||||
|
)
|
||||||
|
|
||||||
pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
||||||
pd_head_ratio = get_ascend_config().pd_head_ratio
|
pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||||
@@ -148,38 +153,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
if mlp_tp_size > 0:
|
if mlp_tp_size > 0:
|
||||||
_MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")
|
_MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")
|
||||||
|
|
||||||
def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
|
|
||||||
#This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
|
|
||||||
group_ranks = []
|
|
||||||
for pp_idx in range(global_pp_size):
|
|
||||||
group = []
|
|
||||||
for dp_idx in range(global_dp_size):
|
|
||||||
base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
|
|
||||||
for i in range(global_tp_size):
|
|
||||||
global_rank = base + i
|
|
||||||
group.append(global_rank)
|
|
||||||
group_ranks.append(group)
|
|
||||||
|
|
||||||
return init_model_parallel_group(group_ranks,
|
|
||||||
get_world_group().local_rank,
|
|
||||||
backend,
|
|
||||||
group_name=group_name)
|
|
||||||
|
|
||||||
global _SHARED_WEIGHT
|
|
||||||
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
|
|
||||||
is_ds_v32 = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
|
|
||||||
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
|
|
||||||
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
|
|
||||||
# TODO: Extract and unify the logic across different communication group.
|
# TODO: Extract and unify the logic across different communication group.
|
||||||
|
flashcomm2_otp_group_ranks = []
|
||||||
if flashcomm2_enable():
|
if flashcomm2_enable():
|
||||||
flashcomm2_otp_size = get_ascend_config(
|
flashcomm2_otp_size = get_ascend_config(
|
||||||
).flashcomm2_oproj_tensor_parallel_size
|
).flashcomm2_oproj_tensor_parallel_size
|
||||||
global_tp_size = get_tp_group().world_size
|
|
||||||
global_dp_size = get_dp_group().world_size
|
|
||||||
global_pp_size = get_pp_group().world_size
|
|
||||||
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
|
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
|
||||||
flashcomm2_otp_size)
|
flashcomm2_otp_size)
|
||||||
|
|
||||||
global _FLASHCOMM2_OTP
|
global _FLASHCOMM2_OTP
|
||||||
global _FLASHCOMM2_ODP
|
global _FLASHCOMM2_ODP
|
||||||
|
|
||||||
@@ -187,7 +167,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
_FLASHCOMM2_ODP = get_tp_group()
|
_FLASHCOMM2_ODP = get_tp_group()
|
||||||
|
|
||||||
if flashcomm2_otp_size > 1:
|
if flashcomm2_otp_size > 1:
|
||||||
otp_group_ranks = []
|
|
||||||
odp_group_ranks: list[list[int]] = [
|
odp_group_ranks: list[list[int]] = [
|
||||||
[] for _ in range(flashcomm2_otp_size * global_dp_size *
|
[] for _ in range(flashcomm2_otp_size * global_dp_size *
|
||||||
global_pp_size)
|
global_pp_size)
|
||||||
@@ -209,10 +188,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
odp_group_index = odp_base_index + j
|
odp_group_index = odp_base_index + j
|
||||||
odp_group_ranks[odp_group_index].append(
|
odp_group_ranks[odp_group_index].append(
|
||||||
global_rank)
|
global_rank)
|
||||||
otp_group_ranks.append(ranks)
|
flashcomm2_otp_group_ranks.append(ranks)
|
||||||
|
|
||||||
_FLASHCOMM2_OTP = init_model_parallel_group(
|
_FLASHCOMM2_OTP = init_model_parallel_group(
|
||||||
otp_group_ranks,
|
flashcomm2_otp_group_ranks,
|
||||||
get_world_group().local_rank,
|
get_world_group().local_rank,
|
||||||
backend,
|
backend,
|
||||||
group_name="flashcomm2_otp")
|
group_name="flashcomm2_otp")
|
||||||
@@ -222,12 +201,50 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
backend,
|
backend,
|
||||||
group_name="flashcomm2_odp")
|
group_name="flashcomm2_odp")
|
||||||
|
|
||||||
# Create shared weight group for flashcomm2 oproj
|
def create_shard_weight_group(
|
||||||
if flashcomm2_o_shared_enabled():
|
module_tp_group_ranks: None) -> GroupCoordinator:
|
||||||
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
|
# Argument module_tp_group_ranks: The module specific tensor parallel group.
|
||||||
if _SHARED_WEIGHT is None:
|
# There are three situations.
|
||||||
_SHARED_WEIGHT = _create_shared_weight_group(
|
# 1. If it is None, then the TP_size of the specific module is 1 and is replicated linear layer.
|
||||||
"flashcomm2_o_shared")
|
# 2. If it is not None, and the module tp_group is same as the global tp_group.
|
||||||
|
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
|
||||||
|
group_ranks = []
|
||||||
|
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(
|
||||||
|
-1, global_pp_size)
|
||||||
|
if module_tp_group_ranks is None:
|
||||||
|
# If it is None, then the TP_size of this shard weight is 1.
|
||||||
|
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
|
||||||
|
group_ranks = [x.tolist() for x in shard_weight_group_ranks]
|
||||||
|
else:
|
||||||
|
# combine standard tp group and non-standard tp group to build shard_weight comm_group
|
||||||
|
module_tp_tanspose_ranks = module_tp_group_ranks.transpose(0, 1)
|
||||||
|
G = world_size // (global_pp_size * module_tp_group_ranks.size(1))
|
||||||
|
shard_weight_group_ranks = torch.stack(
|
||||||
|
[t.view(global_pp_size, G) for t in module_tp_tanspose_ranks],
|
||||||
|
dim=1)
|
||||||
|
group_ranks = shard_weight_group_ranks.view(-1, G).tolist()
|
||||||
|
return init_model_parallel_group(group_ranks,
|
||||||
|
get_world_group().local_rank,
|
||||||
|
backend,
|
||||||
|
group_name="shard_weight")
|
||||||
|
|
||||||
|
# Create shard weight group if enabled
|
||||||
|
if get_ascend_config().layer_sharding is not None:
|
||||||
|
global _SHARD_WEIGHT
|
||||||
|
if flashcomm2_enable():
|
||||||
|
if len(flashcomm2_otp_group_ranks) == 0:
|
||||||
|
FC2_group_ranks = None
|
||||||
|
else:
|
||||||
|
FC2_group_ranks = torch.tensor(
|
||||||
|
flashcomm2_otp_group_ranks).squeeze(0)
|
||||||
|
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
|
||||||
|
elif enable_dsa_cp():
|
||||||
|
# For dsa_cp, all shard layers are replicated.
|
||||||
|
_SHARD_WEIGHT = create_shard_weight_group(None)
|
||||||
|
else:
|
||||||
|
# For standard tp, use global tp group_ranks
|
||||||
|
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
|
||||||
|
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
|
||||||
|
|
||||||
if get_ascend_config().multistream_overlap_gate:
|
if get_ascend_config().multistream_overlap_gate:
|
||||||
global _FC3_QUANT_X
|
global _FC3_QUANT_X
|
||||||
@@ -280,11 +297,10 @@ def get_flashcomm2_odp_group() -> GroupCoordinator:
|
|||||||
return _FLASHCOMM2_ODP
|
return _FLASHCOMM2_ODP
|
||||||
|
|
||||||
|
|
||||||
def get_shared_weight_group() -> GroupCoordinator:
|
def get_shard_weight_group() -> GroupCoordinator:
|
||||||
assert _SHARED_WEIGHT is not None, (
|
assert _SHARD_WEIGHT is not None, (
|
||||||
"output shared weight parallel group for flashcomm2 is not initialized"
|
"output shard weight parallel group for flashcomm2 is not initialized")
|
||||||
)
|
return _SHARD_WEIGHT
|
||||||
return _SHARED_WEIGHT
|
|
||||||
|
|
||||||
|
|
||||||
def get_p_tp_group() -> GroupCoordinator:
|
def get_p_tp_group() -> GroupCoordinator:
|
||||||
@@ -341,10 +357,10 @@ def destroy_ascend_model_parallel():
|
|||||||
_FLASHCOMM2_ODP.destroy()
|
_FLASHCOMM2_ODP.destroy()
|
||||||
_FLASHCOMM2_ODP = None
|
_FLASHCOMM2_ODP = None
|
||||||
|
|
||||||
global _SHARED_WEIGHT
|
global _SHARD_WEIGHT
|
||||||
if _SHARED_WEIGHT:
|
if _SHARD_WEIGHT:
|
||||||
_SHARED_WEIGHT.destroy()
|
_SHARD_WEIGHT.destroy()
|
||||||
_SHARED_WEIGHT = None
|
_SHARD_WEIGHT = None
|
||||||
|
|
||||||
global _FC3_QUANT_X
|
global _FC3_QUANT_X
|
||||||
if _FC3_QUANT_X:
|
if _FC3_QUANT_X:
|
||||||
|
|||||||
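The rank grouping done by `create_shard_weight_group` above can be illustrated without `torch`. The sketch below is not the code from this commit. It assumes a row-major `(dp, pp, tp)` rank layout with no other parallel dimensions, and the helper names `replicated_shard_groups` and `same_tp_index_groups` are hypothetical. The first helper mirrors the `module_tp_group_ranks is None` branch (one group per pipeline stage). The second mirrors the `else` branch for the simplified case `pp == 1` (one group per TP position, spanning the data-parallel replicas).

```python
def replicated_shard_groups(dp: int, pp: int, tp: int) -> list[list[int]]:
    """One group per pipeline stage, holding every rank of that stage.

    Assumes ranks are laid out row-major as (dp, pp, tp); this layout is an
    assumption made for illustration only.
    """
    world_size = dp * pp * tp
    groups: list[list[int]] = [[] for _ in range(pp)]
    for rank in range(world_size):
        stage = (rank // tp) % pp  # pipeline stage of this rank
        groups[stage].append(rank)
    return groups


def same_tp_index_groups(dp: int, tp: int) -> list[list[int]]:
    """One group per TP position, for the simplified case pp == 1.

    Each group collects the ranks that hold the same TP slice of the module
    in every data-parallel replica.
    """
    return [[d * tp + t for d in range(dp)] for t in range(tp)]
```

With `dp=2, pp=2, tp=2` the first helper returns `[[0, 1, 4, 5], [2, 3, 6, 7]]`, and with `dp=2, tp=2` the second returns `[[0, 2], [1, 3]]`.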
@@ -2,10 +2,10 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import (get_dp_group,
|
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
|
||||||
get_fc3_quant_x_group,
|
|
||||||
get_p_tp_group)
|
get_p_tp_group)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -92,11 +92,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
|
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
|
||||||
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
|
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
|
||||||
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
|
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
|
||||||
# This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
|
|
||||||
# which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
|
|
||||||
# We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
|
|
||||||
"VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
|
|
||||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
|
|
||||||
# Whether to enable MLP weight prefetch, only used in small concurrency.
|
# Whether to enable MLP weight prefetch, only used in small concurrency.
|
||||||
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
|
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
|
||||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Optional
|
from functools import lru_cache
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -7,6 +8,8 @@ from vllm.distributed.parallel_state import GroupCoordinator
|
|||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
|
|
||||||
|
from vllm_ascend.distributed.parallel_state import get_shard_weight_group
|
||||||
|
|
||||||
|
|
||||||
def dispose_tensor(x: torch.Tensor):
|
def dispose_tensor(x: torch.Tensor):
|
||||||
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
||||||
@@ -26,17 +29,17 @@ class LayerMetadata:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SharedWindowMetadata:
|
class ShardWindowMetadata:
|
||||||
"""Metadata for a shared window.
|
"""Metadata for a shard window.
|
||||||
"""
|
"""
|
||||||
weight: torch.Tensor # The weight tensor to be shared by layers.
|
weight: torch.Tensor # The weight tensor to be shard by layers.
|
||||||
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
||||||
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SeriesMetadata:
|
class SeriesMetadata:
|
||||||
"""Metadata for a weight shared series.
|
"""Metadata for a weight shard series.
|
||||||
"""
|
"""
|
||||||
group: GroupCoordinator
|
group: GroupCoordinator
|
||||||
start_layer: int
|
start_layer: int
|
||||||
@@ -45,8 +48,8 @@ class SeriesMetadata:
|
|||||||
prefetch_step: int
|
prefetch_step: int
|
||||||
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
||||||
layers: list[LayerMetadata]
|
layers: list[LayerMetadata]
|
||||||
shared_windows: list[
|
shard_windows: list[
|
||||||
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
ShardWindowMetadata] # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
||||||
window_offset: int # The index of the window for the next coming layer.
|
window_offset: int # The index of the window for the next coming layer.
|
||||||
|
|
||||||
def is_source(self, layer_idx) -> bool:
|
def is_source(self, layer_idx) -> bool:
|
||||||
@@ -54,7 +57,7 @@ class SeriesMetadata:
|
|||||||
|
|
||||||
def post_process_after_loading(self):
|
def post_process_after_loading(self):
|
||||||
# This method only needs to be called once per series.
|
# This method only needs to be called once per series.
|
||||||
if self.shared_windows:
|
if self.shard_windows:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.layers.sort(key=lambda x: x.layer_idx)
|
self.layers.sort(key=lambda x: x.layer_idx)
|
||||||
@@ -83,8 +86,8 @@ class SeriesMetadata:
|
|||||||
step = layer_idx - self.start_layer
|
step = layer_idx - self.start_layer
|
||||||
if step < self.prefetch_step:
|
if step < self.prefetch_step:
|
||||||
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
|
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
|
||||||
self.shared_windows.append(
|
self.shard_windows.append(
|
||||||
SharedWindowMetadata(
|
ShardWindowMetadata(
|
||||||
weight=layer.weight.clone().detach(),
|
weight=layer.weight.clone().detach(),
|
||||||
data_layer_idx=layer_idx,
|
data_layer_idx=layer_idx,
|
||||||
work=None,
|
work=None,
|
||||||
@@ -92,12 +95,12 @@ class SeriesMetadata:
|
|||||||
layer.window_idx = step
|
layer.window_idx = step
|
||||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||||
if not is_source:
|
if not is_source:
|
||||||
layer.weight.set_(self.shared_windows[-1].weight)
|
layer.weight.set_(self.shard_windows[-1].weight)
|
||||||
else:
|
else:
|
||||||
# Build one more window for prefetch. The weight is useless, so just keep the shape.
|
# Build one more window for prefetch. The weight is useless, so just keep the shape.
|
||||||
if step == self.prefetch_step:
|
if step == self.prefetch_step:
|
||||||
self.shared_windows.append(
|
self.shard_windows.append(
|
||||||
SharedWindowMetadata(
|
ShardWindowMetadata(
|
||||||
weight=torch.empty_like(layer.weight),
|
weight=torch.empty_like(layer.weight),
|
||||||
data_layer_idx=-1,
|
data_layer_idx=-1,
|
||||||
work=None,
|
work=None,
|
||||||
@@ -115,7 +118,7 @@ class SeriesMetadata:
|
|||||||
next_layer = self.layers[next_layer_idx - self.start_layer]
|
next_layer = self.layers[next_layer_idx - self.start_layer]
|
||||||
# The index of the window to store the weight for the coming layer.
|
# The index of the window to store the weight for the coming layer.
|
||||||
next_layer.window_idx = self.window_offset
|
next_layer.window_idx = self.window_offset
|
||||||
window = self.shared_windows[next_layer.window_idx]
|
window = self.shard_windows[next_layer.window_idx]
|
||||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||||
if not self.is_source(next_layer_idx):
|
if not self.is_source(next_layer_idx):
|
||||||
next_layer.weight.set_(window.weight)
|
next_layer.weight.set_(window.weight)
|
||||||
@@ -133,10 +136,10 @@ class SeriesMetadata:
|
|||||||
|
|
||||||
def wait_weight(self, layer_idx: int):
|
def wait_weight(self, layer_idx: int):
|
||||||
# Find the asynchronous broadcast work and wait for it.
|
# Find the asynchronous broadcast work and wait for it.
|
||||||
assert self.shared_windows
|
assert self.shard_windows
|
||||||
window = self.shared_windows[self.layers[layer_idx -
|
window = self.shard_windows[self.layers[layer_idx -
|
||||||
self.start_layer].window_idx]
|
self.start_layer].window_idx]
|
||||||
# Make sure the data in the corresponding shared window is for the current layer.
|
# Make sure the data in the corresponding shard window is for the current layer.
|
||||||
assert window.data_layer_idx == layer_idx
|
assert window.data_layer_idx == layer_idx
|
||||||
if window.work is not None:
|
if window.work is not None:
|
||||||
window.work.wait()
|
window.work.wait()
|
||||||
@@ -168,13 +171,13 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Register linear layers into a shared storage series.
|
Register linear layers into a shard storage series.
|
||||||
|
|
||||||
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
|
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
|
||||||
|
|
||||||
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
|
After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization.
|
||||||
|
|
||||||
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
|
During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
|
||||||
|
|
||||||
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
|
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
|
||||||
- start_layer is the index of the first layer in the series (inclusive).
|
- start_layer is the index of the first layer in the series (inclusive).
|
||||||
@@ -182,7 +185,7 @@ Note: The layers are managed as a circular buffer. The index of the layer to pre
|
|||||||
- total_layers = end_layer - start_layer
|
- total_layers = end_layer - start_layer
|
||||||
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
|
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
|
||||||
|
|
||||||
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
|
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
series_name: This name identifies which series this layer belongs to.
|
series_name: This name identifies which series this layer belongs to.
|
||||||
@@ -192,7 +195,7 @@ Arguments:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def register_layer_to_shared_weight_series(
|
def register_layer_to_shard_weight_series(
|
||||||
series_name: str,
|
series_name: str,
|
||||||
group: GroupCoordinator,
|
group: GroupCoordinator,
|
||||||
layer: LinearBase,
|
layer: LinearBase,
|
||||||
@@ -208,7 +211,7 @@ def register_layer_to_shared_weight_series(
|
|||||||
prefetch_step=prefetch_step,
|
prefetch_step=prefetch_step,
|
||||||
dummy_weight=torch.empty_like(layer.weight),
|
dummy_weight=torch.empty_like(layer.weight),
|
||||||
layers=[],
|
layers=[],
|
||||||
shared_windows=[],
|
shard_windows=[],
|
||||||
window_offset=prefetch_step,
|
window_offset=prefetch_step,
|
||||||
)
|
)
|
||||||
series = _series_dict[series_name]
|
series = _series_dict[series_name]
|
||||||
@@ -236,17 +239,42 @@ def register_layer_to_shared_weight_series(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
|
def post_process_after_loading_for_shard_weight_series(layer: LinearBase):
|
||||||
ext = _layer_external_dict[id(layer)]
|
ext = _layer_external_dict[id(layer)]
|
||||||
ext.series.post_process_after_loading()
|
ext.series.post_process_after_loading()
|
||||||
|
|
||||||
|
|
||||||
def reach_layer_for_shared_weight_series(layer: LinearBase):
|
def reach_layer_for_shard_weight_series(layer: LinearBase):
|
||||||
ext = _layer_external_dict[id(layer)]
|
ext = _layer_external_dict[id(layer)]
|
||||||
ext.series.reach_layer(ext.layer_idx)
|
ext.series.reach_layer(ext.layer_idx)
|
||||||
|
|
||||||
|
|
||||||
def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
|
def wait_layer_for_shard_weight_series(layer: LinearBase):
|
||||||
num_hidden_layers = vllm_config.model_config.hf_text_config.num_hidden_layers
|
ext = _layer_external_dict[id(layer)]
|
||||||
|
ext.series.wait_weight(ext.layer_idx)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_current_model_num_hidden_layers() -> int:
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
return vllm_config.model_config.get_total_num_hidden_layers()
|
||||||
|
|
||||||
|
|
||||||
|
def is_hidden_layer(layer: LinearBase) -> bool:
|
||||||
|
num_hidden_layers = get_current_model_num_hidden_layers()
|
||||||
layer_idx = extract_layer_index(layer.prefix)
|
layer_idx = extract_layer_index(layer.prefix)
|
||||||
return layer_idx < num_hidden_layers
|
return layer_idx < num_hidden_layers
|
||||||
|
|
||||||
|
|
||||||
|
def register_all_layers_to_shard_weight_series(
|
||||||
|
layer_sharding: List[LinearBase], ):
|
||||||
|
for curr_layer in (layer_sharding or []):
|
||||||
|
if is_hidden_layer(curr_layer):
|
||||||
|
layer_name = curr_layer.prefix.split('.')[-1]
|
||||||
|
register_layer_to_shard_weight_series(
|
||||||
|
series_name=layer_name,
|
||||||
|
group=get_shard_weight_group(),
|
||||||
|
layer=curr_layer,
|
||||||
|
prefetch_step=1,
|
||||||
|
)
|
||||||
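The docstring in this file describes three rules: layer `i` is stored on device `i % n`, the layer to prefetch wraps around like a circular buffer, and `k + 1` window buffers are kept for a prefetch step of `k`. A minimal sketch of that arithmetic follows. The function names are hypothetical, and the sketch assumes that `layer_idx` in the docstring formula is measured relative to `start_layer` (the two readings coincide when `start_layer` is 0).

```python
def owner_rank(layer_idx: int, group_size: int) -> int:
    """Device that stores the weight of layer `layer_idx` (i % n rule)."""
    return layer_idx % group_size


def prefetch_layer_idx(layer_idx: int, prefetch_step: int,
                       start_layer: int, end_layer: int) -> int:
    """Layer whose weight is prefetched when `layer_idx` is reached."""
    total_layers = end_layer - start_layer
    # Assumption: convert the absolute index to an offset inside the series.
    offset = layer_idx - start_layer
    return (offset + prefetch_step) % total_layers + start_layer


def num_windows(prefetch_step: int) -> int:
    """Buffers needed: the current layer plus `prefetch_step` prefetched ones."""
    return prefetch_step + 1
```

For a series covering layers 0 to 3 with `prefetch_step=1`, reaching layer 3 prefetches layer 0 again, and two window buffers are enough.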
@@ -23,6 +23,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import lru_cache
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -1016,22 +1017,27 @@ def flashcomm2_enable() -> bool:
|
|||||||
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
|
return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
|
||||||
|
|
||||||
|
|
||||||
def flashcomm2_o_shared_enabled() -> bool:
|
|
||||||
return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
|
|
||||||
|
|
||||||
|
|
||||||
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
||||||
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
|
flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
|
||||||
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
|
global_tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||||
flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
|
|
||||||
|
|
||||||
if not flashcomm2_enable():
|
if not flashcomm2_enable():
|
||||||
flashcomm2_oproj_shared = False
|
return 0
|
||||||
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
|
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
layer_sharding = ascend_config.layer_sharding or []
|
||||||
|
if layer_sharding:
|
||||||
|
if layer_sharding == ["o_proj"]:
|
||||||
|
logger.info_once(
|
||||||
|
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
|
||||||
|
f"Found invalid layer_sharding: {layer_sharding}")
|
||||||
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
|
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
|
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
|
||||||
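The `layer_sharding` check added to `get_flashcomm2_config_and_validate` in the hunk above accepts an empty setting or exactly `["o_proj"]` and rejects anything else. A standalone sketch of that rule, with logging omitted and a hypothetical function name, could look like this:

```python
from typing import Optional


def validate_flashcomm2_layer_sharding(
        layer_sharding: Optional[list[str]]) -> list[str]:
    """Return the normalized list, or raise if FLASHCOMM2 cannot use it."""
    names = layer_sharding or []  # None and [] both mean "feature off"
    if names and names != ["o_proj"]:
        raise ValueError(
            "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding "
            f"configuration! Found invalid layer_sharding: {names}")
    return names
```

Under this rule the example value from the commit message, `["o_proj", "q_b_proj"]`, would raise when FLASHCOMM2 is enabled.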
@@ -1054,13 +1060,10 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
|||||||
)
|
)
|
||||||
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
|
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"FLASHCOMM2 primarily targets P-scenario deployments, "
|
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
|
||||||
"with additional support for hybrid deployment scenarios. "
|
)
|
||||||
"It is not applicable in D-scenario environments.")
|
|
||||||
if flashcomm2_oproj_shared:
|
|
||||||
logger.info("Enable FLASHCOMM2 with oproj_shared.")
|
|
||||||
|
|
||||||
return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
|
return flashcomm2_oproj_tp_size
|
||||||
|
|
||||||
|
|
||||||
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
|
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
|
||||||
@@ -1160,4 +1163,20 @@ def singleton(cls):
|
|||||||
instances[cls] = cls(*args, **kwargs)
|
instances[cls] = cls(*args, **kwargs)
|
||||||
return instances[cls]
|
return instances[cls]
|
||||||
|
|
||||||
return get_instance
|
return get_instance
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_current_model_config():
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
return vllm_config.model_config
|
||||||
|
|
||||||
|
|
||||||
|
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def enable_dsa_cp() -> bool:
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
||||||
|
return is_ds_v32 and enable_sp()
|
||||||
|
|||||||