[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce layer_sharding as a configurable feature in
additional_config
- Revise the term "shared weight" to "shard weight."
Configuration: the feature is opt-in via the `additional_config`
argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight-replication
strategies, and is treated as a separate, explicit feature. It can be used
in any scenario, and when combined with the
FlashComm2 (https://github.com/vllm-project/vllm-ascend/pull/3232) feature
or the ShardedCP (#4702) feature it achieves significant performance gains.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
@@ -49,6 +49,7 @@ The following table lists additional configuration options available in vLLM Asc
| `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |
| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
| `layer_sharding` | list | `None` | List of linear-layer names (e.g., `["o_proj"]`) whose weights are sharded layer-by-layer across devices. |
The details of each configuration option are as follows:
BIN
docs/source/user_guide/feature_guide/images/layer_sharding.png
Normal file
After Width: | Height: | Size: 156 KiB |
@@ -19,6 +19,7 @@ external_dp
large_scale_ep
ucm_deployment
Fine_grained_TP
layer_sharding
speculative_decoding
context_parallel
:::
73
docs/source/user_guide/feature_guide/layer_sharding.md
Normal file
@@ -0,0 +1,73 @@
---
title: Layer Sharding Guide
---

# Overview

**Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.

Instead of replicating all weights on every device, **Layer Shard Linear shards the weights of a "series" of such operators across the NPU devices in a communication group**:

- The **i-th layer's linear weight** is stored **only on device `i % K`**, where `K` is the number of devices in the group.
- Other devices hold a lightweight **shared dummy tensor** during initialization and fetch the real weight **on demand via asynchronous broadcast** during the forward pass.
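The ownership rule above can be sketched in a few lines. The helper names below are hypothetical (the real logic lives in `vllm_ascend.ops.layer_shard_linear`); this only illustrates the round-robin layer-to-device mapping:

```python
def shard_owner(layer_idx: int, group_size: int) -> int:
    """Rank that stores the real weight for layer `layer_idx`."""
    return layer_idx % group_size

def is_source(layer_idx: int, rank: int, group_size: int) -> bool:
    """True if this rank holds the real weight and will broadcast it."""
    return shard_owner(layer_idx, group_size) == rank

# With K = 4 devices, layers 0..7 are owned round-robin:
owners = [shard_owner(i, 4) for i in range(8)]
print(owners)  # [0, 1, 2, 3, 0, 1, 2, 3]
```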

As illustrated in the figure below, this design lets the broadcast run ahead of the computation: while the current layer (e.g., MLA or MoE) is being computed, the system **asynchronously broadcasts the next layer's weight** in the background. Because the attention computation in the MLA module is sufficiently latency-bound, the weight transfer for `o_proj` is **fully overlapped with computation**, making the communication **latency-free from the perspective of end-to-end inference**.

This approach **preserves exact computational semantics** while **significantly reducing the NPU memory footprint**, which is especially critical for:

- Extremely deep architectures (e.g., DeepSeek-V3/R1 with 61 layers);
- Models using **[DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702)** or **[FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188)**, where the full `O` (output) projection matrix must reside in memory per layer;
- Scenarios where **attention computation latency fully overlaps** (hides) the communication cost of weight broadcasting.

---

## Flowchart

 

> **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes, enabling zero-overhead weight loading.
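The prefetch window in the workflow behaves like a small ring buffer of `prefetch_step + 1` weight slots that the layers take turns using. A simplified, self-contained model of that bookkeeping (hypothetical class, loosely mirroring the `window_offset` rotation in the PR's `SeriesMetadata`; the real broadcast is issued where `reach_layer` picks a slot):

```python
class PrefetchWindows:
    """Ring buffer of (prefetch_step + 1) weight slots shared across layers."""

    def __init__(self, num_layers: int, prefetch_step: int = 1):
        self.num_layers = num_layers
        self.size = prefetch_step + 1
        # data_layer_idx[w] records which layer's weight slot w currently holds.
        self.data_layer_idx = [-1] * self.size
        self.offset = 0  # slot that the next prefetched layer will use

    def reach_layer(self, layer_idx: int) -> int:
        """Called when layer `layer_idx` starts computing: pick the slot that
        the async broadcast of the next layer's weight would fill."""
        next_layer = (layer_idx + 1) % self.num_layers
        slot = self.offset
        self.data_layer_idx[slot] = next_layer
        self.offset = (self.offset + 1) % self.size
        return slot

w = PrefetchWindows(num_layers=8, prefetch_step=1)
slots = [w.reach_layer(i) for i in range(4)]
print(slots)  # the two slots alternate: [0, 1, 0, 1]
```

With `prefetch_step=1` there are exactly two slots, so the weight for layer `i+1` lands in the slot not currently used by layer `i`.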

---

# Getting Started

To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:

```bash
--additional-config '{
    "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

---

# Supported Scenarios

This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases:

## FlashComm2-enabled

When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.

**Example configuration:**

```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
vllm serve \
    --model DeepSeek-V3/R1 \
    --additional-config '{
        "layer_sharding": ["o_proj"]
    }'
```

## DSA-CP-enabled

With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both the `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.

**Example configuration:**

```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
vllm serve \
    --model DeepSeek-V3.2 \
    --additional-config '{
        "layer_sharding": ["q_b_proj", "o_proj"]
    }'
```
@@ -25,14 +25,10 @@ def mock_distributed():
patch('torch.distributed.get_world_size', return_value=16), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
mock_tp_group.return_value.world_size = 4
mock_dp_group.return_value.world_size = 2
mock_pp_group.return_value.world_size = 2
yield
@@ -50,7 +46,6 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_vllm_config.kv_transfer_config.is_kv_producer = True
mock_envs_ascend = MagicMock()
mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
@@ -51,6 +51,12 @@ class AscendConfig:
"weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
self.layer_sharding = additional_config.get("layer_sharding", None)
logger.info_once(
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
"Note: This feature works optimally with FLASHCOMM2 and DSA-CP enabled; "
"using it without these features may result in significant performance degradation."
)

# TODO: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm, remove this config.
self.expert_map_path = additional_config.get("expert_map_path", None)
@@ -111,7 +117,7 @@ class AscendConfig:
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
from vllm_ascend.utils import get_flashcomm2_config_and_validate
self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
self, vllm_config)
self.enable_npugraph_ex = additional_config.get(
"enable_npugraph_ex", False)
@@ -31,15 +31,14 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.compilation.acl_graph import (
get_draft_graph_params, get_graph_params,
update_draft_graph_params_workspaces, update_graph_params_workspaces)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
flashcomm2_o_shared_enabled, maybe_trans_nz,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -734,18 +733,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.vllm_config = get_current_vllm_config()
self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)

self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -762,6 +749,15 @@ class AscendMLAImpl(MLAAttentionImpl):
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
@@ -833,9 +829,9 @@ class AscendMLAImpl(MLAAttentionImpl):
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)

def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
@@ -1445,9 +1441,9 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)

if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)

decode_preprocess_res = None
prefill_preprocess_res = None
@@ -1478,9 +1474,9 @@ class AscendMLAImpl(MLAAttentionImpl):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)

forward_context = get_forward_context()
@@ -25,11 +25,11 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
@@ -374,22 +374,17 @@ class AscendSFAImpl(MLAAttentionImpl):
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size

# TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
self._replace_linear_class_for_sfa_cp()
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
if is_hidden_layer(self.vllm_config, self.q_proj):
register_layer_to_shared_weight_series(
series_name="q_proj",
group=get_shared_weight_group(),
layer=self.q_proj,
prefetch_step=1)
if is_hidden_layer(self.vllm_config, self.o_proj):
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)

# indexer param
self.n_head: int = self.indexer.n_head  # 64
@@ -434,14 +429,10 @@ class AscendSFAImpl(MLAAttentionImpl):

# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)

if self.enable_sfa_cp:
if is_hidden_layer(self.vllm_config, self.q_proj):
post_process_after_loading_for_shared_weight_series(
self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(
self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)

if self.enable_mlapo:
quant_method = getattr(
@@ -751,10 +742,9 @@ class AscendSFAImpl(MLAAttentionImpl):
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
@@ -809,10 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
slot_mapping_cp)

if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)

ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
@@ -2,14 +2,12 @@ from typing import Optional

import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
get_pp_group, get_tp_group,
from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
get_world_group,
init_model_parallel_group)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
flashcomm2_o_shared_enabled)
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable

# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -25,8 +23,8 @@ _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None

# shared_weight across rank groups
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
# shard_weight across rank groups
_SHARD_WEIGHT: Optional[GroupCoordinator] = None

_P_TP: Optional[GroupCoordinator] = None
@@ -37,7 +35,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
vllm_config = get_current_vllm_config()
global_tp_size = parallel_config.tensor_parallel_size
global_dp_size = parallel_config.data_parallel_size
global_pp_size = parallel_config.pipeline_parallel_size
@@ -48,6 +45,14 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
all_ranks = torch.arange(world_size).reshape(
-1, global_dp_size * parallel_config.prefill_context_parallel_size *
global_tp_size)
# TODO: all_ranks should be the same as vllm_all_ranks; all_ranks needs to be removed in the future.
vllm_all_ranks = torch.arange(world_size).reshape(
-1,
global_dp_size,
global_pp_size,
parallel_config.prefill_context_parallel_size,
global_tp_size,
)

pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio
@@ -148,38 +153,13 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if mlp_tp_size > 0:
_MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")

def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
# This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
group_ranks = []
for pp_idx in range(global_pp_size):
group = []
for dp_idx in range(global_dp_size):
base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
for i in range(global_tp_size):
global_rank = base + i
group.append(global_rank)
group_ranks.append(group)

return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=group_name)

global _SHARED_WEIGHT
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
# TODO: Extract and unify the logic across different communication group.
flashcomm2_otp_group_ranks = []
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size
global_tp_size = get_tp_group().world_size
global_dp_size = get_dp_group().world_size
global_pp_size = get_pp_group().world_size
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)

global _FLASHCOMM2_OTP
global _FLASHCOMM2_ODP
@@ -187,7 +167,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
_FLASHCOMM2_ODP = get_tp_group()

if flashcomm2_otp_size > 1:
otp_group_ranks = []
odp_group_ranks: list[list[int]] = [
[] for _ in range(flashcomm2_otp_size * global_dp_size *
global_pp_size)
@@ -209,10 +188,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
odp_group_index = odp_base_index + j
odp_group_ranks[odp_group_index].append(
global_rank)
otp_group_ranks.append(ranks)
flashcomm2_otp_group_ranks.append(ranks)

_FLASHCOMM2_OTP = init_model_parallel_group(
otp_group_ranks,
flashcomm2_otp_group_ranks,
get_world_group().local_rank,
backend,
group_name="flashcomm2_otp")
@@ -222,12 +201,50 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="flashcomm2_odp")

# Create shared weight group for flashcomm2 oproj
if flashcomm2_o_shared_enabled():
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
if _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group(
"flashcomm2_o_shared")
def create_shard_weight_group(
module_tp_group_ranks: Optional[torch.Tensor]) -> GroupCoordinator:
# Argument module_tp_group_ranks: the module-specific tensor parallel group.
# There are three situations:
# 1. It is None: the TP size of the specific module is 1, and the layer is a replicated linear layer.
# 2. It is not None, and the module tp_group is the same as the global tp_group.
# 3. It is not None, and the module tp_group differs from the global tp_group (e.g. flashcomm2_otp).
group_ranks = []
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(
-1, global_pp_size)
if module_tp_group_ranks is None:
# If it is None, then the TP size of this shard weight is 1.
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
group_ranks = [x.tolist() for x in shard_weight_group_ranks]
else:
# Combine the standard tp group and the non-standard tp group to build the shard_weight comm group.
module_tp_transpose_ranks = module_tp_group_ranks.transpose(0, 1)
G = world_size // (global_pp_size * module_tp_group_ranks.size(1))
shard_weight_group_ranks = torch.stack(
[t.view(global_pp_size, G) for t in module_tp_transpose_ranks],
dim=1)
group_ranks = shard_weight_group_ranks.view(-1, G).tolist()
return init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="shard_weight")

# Create shard weight group if enabled
if get_ascend_config().layer_sharding is not None:
global _SHARD_WEIGHT
if flashcomm2_enable():
if len(flashcomm2_otp_group_ranks) == 0:
FC2_group_ranks = None
else:
FC2_group_ranks = torch.tensor(
flashcomm2_otp_group_ranks).squeeze(0)
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
elif enable_dsa_cp():
# For dsa_cp, all shard layers are replicated.
_SHARD_WEIGHT = create_shard_weight_group(None)
else:
# For standard tp, use the global tp group_ranks.
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)

if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
@@ -280,11 +297,10 @@ def get_flashcomm2_odp_group() -> GroupCoordinator:
return _FLASHCOMM2_ODP


def get_shared_weight_group() -> GroupCoordinator:
assert _SHARED_WEIGHT is not None, (
"output shared weight parallel group for flashcomm2 is not initialized"
)
return _SHARED_WEIGHT
def get_shard_weight_group() -> GroupCoordinator:
assert _SHARD_WEIGHT is not None, (
"shard weight parallel group is not initialized")
return _SHARD_WEIGHT


def get_p_tp_group() -> GroupCoordinator:
@@ -341,10 +357,10 @@ def destroy_ascend_model_parallel():
_FLASHCOMM2_ODP.destroy()
_FLASHCOMM2_ODP = None

global _SHARED_WEIGHT
if _SHARED_WEIGHT:
_SHARED_WEIGHT.destroy()
_SHARED_WEIGHT = None
global _SHARD_WEIGHT
if _SHARD_WEIGHT:
_SHARD_WEIGHT.destroy()
_SHARD_WEIGHT = None

global _FC3_QUANT_X
if _FC3_QUANT_X:
@@ -2,10 +2,10 @@ import os

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.distributed.parallel_state import (get_dp_group,
get_fc3_quant_x_group,
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
get_p_tp_group)
@@ -92,11 +92,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
# between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
# This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
# which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
# We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
"VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Callable, Optional
from functools import lru_cache
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
@@ -7,6 +8,8 @@ from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.models.utils import extract_layer_index

from vllm_ascend.distributed.parallel_state import get_shard_weight_group


def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
@@ -26,17 +29,17 @@ class LayerMetadata:


@dataclass
class SharedWindowMetadata:
"""Metadata for a shared window.
class ShardWindowMetadata:
"""Metadata for a shard window.
"""
weight: torch.Tensor  # The weight tensor to be shared by layers.
weight: torch.Tensor  # The weight tensor backing this shard window.
data_layer_idx: int  # The index of the layer this window's weight is equal to.
work: Optional[torch.distributed.Work]  # The asynchronous broadcast work.


@dataclass
class SeriesMetadata:
"""Metadata for a weight shared series.
"""Metadata for a weight shard series.
"""
group: GroupCoordinator
start_layer: int
@@ -45,8 +48,8 @@ class SeriesMetadata:
prefetch_step: int
dummy_weight: torch.Tensor  # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
layers: list[LayerMetadata]
shared_windows: list[
SharedWindowMetadata]  # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
shard_windows: list[
ShardWindowMetadata]  # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
window_offset: int  # The index of the window for the next coming layer.
def is_source(self, layer_idx) -> bool:
@@ -54,7 +57,7 @@ class SeriesMetadata:

def post_process_after_loading(self):
# This method only needs to be called once per series.
if self.shared_windows:
if self.shard_windows:
return

self.layers.sort(key=lambda x: x.layer_idx)
@@ -83,8 +86,8 @@ class SeriesMetadata:
step = layer_idx - self.start_layer
if step < self.prefetch_step:
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
self.shared_windows.append(
SharedWindowMetadata(
self.shard_windows.append(
ShardWindowMetadata(
weight=layer.weight.clone().detach(),
data_layer_idx=layer_idx,
work=None,
@@ -92,12 +95,12 @@ class SeriesMetadata:
layer.window_idx = step
# When the layer is not intended to be stored on this device, link to the corresponding window's tensor.
if not is_source:
layer.weight.set_(self.shared_windows[-1].weight)
layer.weight.set_(self.shard_windows[-1].weight)
else:
# Build one more window for prefetch. The weight is useless, so just keep the shape.
if step == self.prefetch_step:
self.shared_windows.append(
SharedWindowMetadata(
self.shard_windows.append(
ShardWindowMetadata(
weight=torch.empty_like(layer.weight),
data_layer_idx=-1,
work=None,
@@ -115,7 +118,7 @@ class SeriesMetadata:
         next_layer = self.layers[next_layer_idx - self.start_layer]
         # The index of the window to store the weight for the coming layer.
         next_layer.window_idx = self.window_offset
-        window = self.shared_windows[next_layer.window_idx]
+        window = self.shard_windows[next_layer.window_idx]
         # When the layer not intended to be stored in this device, link to the corresponding window's tensor.
         if not self.is_source(next_layer_idx):
            next_layer.weight.set_(window.weight)
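Outside the diff, the circular reuse of windows that `reach_layer` implements can be modeled in a few lines of plain Python. This is an illustrative sketch, not the real `SeriesMetadata` code: the names and the assumption that `window_offset` advances modulo the pool size are simplifications consistent with the circular-buffer note in the series docstring.

```python
# Illustrative model of the shard-window pool: (prefetch_step + 1) slots
# reused as a circular buffer; `window_offset` points at the slot that
# will hold the weights of the next incoming layer.
PREFETCH_STEP = 1
NUM_WINDOWS = PREFETCH_STEP + 1  # current layer + k prefetched layers


def advance_window_offset(window_offset: int) -> int:
    """Move the circular pointer one slot forward (wraps around)."""
    return (window_offset + 1) % NUM_WINDOWS


# Starting from the initial value window_offset=prefetch_step, the pointer
# simply alternates between the two slots when prefetch_step == 1.
offsets = []
offset = PREFETCH_STEP
for _ in range(4):
    offsets.append(offset)
    offset = advance_window_offset(offset)
```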
@@ -133,10 +136,10 @@ class SeriesMetadata:

     def wait_weight(self, layer_idx: int):
         # Find the asynchronous broadcast work and wait for it.
-        assert self.shared_windows
-        window = self.shared_windows[self.layers[layer_idx -
-                                                 self.start_layer].window_idx]
-        # Make sure the data in the corresponding shared window is for the current layer.
+        assert self.shard_windows
+        window = self.shard_windows[self.layers[layer_idx -
+                                                self.start_layer].window_idx]
+        # Make sure the data in the corresponding shard window is for the current layer.
         assert window.data_layer_idx == layer_idx
         if window.work is not None:
             window.work.wait()
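The prefetch-then-wait protocol in `wait_weight` (start an asynchronous transfer, stash the handle in the window's `work` field, block on it just before use) can be mimicked with standard-library futures. A hedged sketch — `Window`, `prefetch`, and the string payload are illustrative stand-ins, and `Future.result()` plays the role of the real `work.wait()`:

```python
from concurrent.futures import ThreadPoolExecutor


class Window:
    """Toy stand-in for ShardWindowMetadata: a buffer plus an async handle."""

    def __init__(self):
        self.data = None
        self.work = None


def prefetch(pool: ThreadPoolExecutor, window: Window, layer_idx: int) -> None:
    # Kick off the "transfer" asynchronously and remember the handle.
    def _fill():
        window.data = f"weights-for-layer-{layer_idx}"

    window.work = pool.submit(_fill)


def wait_weight(window: Window):
    # Block until the in-flight transfer (if any) has completed.
    if window.work is not None:
        window.work.result()  # analogous to work.wait()
        window.work = None
    return window.data


with ThreadPoolExecutor(max_workers=1) as pool:
    w = Window()
    prefetch(pool, w, 3)
    result = wait_weight(w)
```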
@@ -168,13 +171,13 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,


 """
-Register linear layers into a shared storage series.
+Register linear layers into a shard storage series.

 In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.

-After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
+After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization.

-During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to the `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
+During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to the `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.

 Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
 - start_layer is the index of the first layer in the series (inclusive).
@@ -182,7 +185,7 @@ Note: The layers are managed as a circular buffer. The index of the layer to pre
 - total_layers = end_layer - start_layer
 - prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer

-To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
+To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series.

 Arguments:
     series_name: This name identifies which series this layer belongs to.
@@ -192,7 +195,7 @@ Arguments:
 """


-def register_layer_to_shared_weight_series(
+def register_layer_to_shard_weight_series(
         series_name: str,
         group: GroupCoordinator,
         layer: LinearBase,
@@ -208,7 +211,7 @@ def register_layer_to_shared_weight_series(
             prefetch_step=prefetch_step,
             dummy_weight=torch.empty_like(layer.weight),
             layers=[],
-            shared_windows=[],
+            shard_windows=[],
             window_offset=prefetch_step,
         )
     series = _series_dict[series_name]
@@ -236,17 +239,42 @@ def register_layer_to_shared_weight_series(
     )


-def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
+def post_process_after_loading_for_shard_weight_series(layer: LinearBase):
     ext = _layer_external_dict[id(layer)]
     ext.series.post_process_after_loading()


-def reach_layer_for_shared_weight_series(layer: LinearBase):
+def reach_layer_for_shard_weight_series(layer: LinearBase):
     ext = _layer_external_dict[id(layer)]
     ext.series.reach_layer(ext.layer_idx)


-def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
-    num_hidden_layers = vllm_config.model_config.hf_text_config.num_hidden_layers
+def wait_layer_for_shard_weight_series(layer: LinearBase):
+    ext = _layer_external_dict[id(layer)]
+    ext.series.wait_weight(ext.layer_idx)
+
+
+@lru_cache(maxsize=1)
+def get_current_model_num_hidden_layers() -> int:
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    return vllm_config.model_config.get_total_num_hidden_layers()
+
+
+def is_hidden_layer(layer: LinearBase) -> bool:
+    num_hidden_layers = get_current_model_num_hidden_layers()
     layer_idx = extract_layer_index(layer.prefix)
     return layer_idx < num_hidden_layers
+
+
+def register_all_layers_to_shard_weight_series(
+        layer_sharding: List[LinearBase], ):
+    for curr_layer in (layer_sharding or []):
+        if is_hidden_layer(curr_layer):
+            layer_name = curr_layer.prefix.split('.')[-1]
+            register_layer_to_shard_weight_series(
+                series_name=layer_name,
+                group=get_shard_weight_group(),
+                layer=curr_layer,
+                prefetch_step=1,
+            )
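`register_all_layers_to_shard_weight_series` groups layers into series by the last component of their dotted prefix, so e.g. every `o_proj` across the hidden layers lands in one series. A small sketch of that grouping — the prefixes and the simplified index parser are illustrative, not vLLM's actual `extract_layer_index`:

```python
def series_name_from_prefix(prefix: str) -> str:
    """Series name = last dotted component, mirroring prefix.split('.')[-1]."""
    return prefix.split('.')[-1]


def layer_index_from_prefix(prefix: str) -> int:
    """Simplified stand-in for extract_layer_index: first numeric component."""
    for part in prefix.split('.'):
        if part.isdigit():
            return int(part)
    raise ValueError(f"no layer index in {prefix!r}")


prefixes = [
    "model.layers.0.self_attn.o_proj",
    "model.layers.1.self_attn.o_proj",
    "model.layers.0.self_attn.q_b_proj",
]
series = {}
for p in prefixes:
    series.setdefault(series_name_from_prefix(p), []).append(
        layer_index_from_prefix(p))
```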
@@ -23,6 +23,7 @@ import math
 import os
 from contextlib import contextmanager, nullcontext
 from enum import Enum
+from functools import lru_cache
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -1016,22 +1017,27 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0


-def flashcomm2_o_shared_enabled() -> bool:
-    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
-
-
 def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
-    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()

     if not flashcomm2_enable():
-        flashcomm2_oproj_shared = False
-        return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
+        return 0

     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
     )

+    layer_sharding = ascend_config.layer_sharding or []
+    if layer_sharding:
+        if layer_sharding == ["o_proj"]:
+            logger.info_once(
+                "Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
+            )
+        else:
+            raise ValueError(
+                "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
+                f"Found invalid layer_sharding: {layer_sharding}")
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
             "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1054,13 +1060,10 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
         )
     if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
         raise AssertionError(
-            "FLASHCOMM2 primarily targets P-scenario deployments, "
-            "with additional support for hybrid deployment scenarios. "
-            "It is not applicable in D-scenario environments.")
-    if flashcomm2_oproj_shared:
-        logger.info("Enable FLASHCOMM2 with oproj_shared.")
+            "FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
+        )

-    return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
+    return flashcomm2_oproj_tp_size
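The new validation path accepts `["o_proj"]` as the only layer-sharding configuration when FLASHCOMM2 is on and rejects everything else. A sketch of just that check, detached from `ascend_config` (the function name is ours):

```python
def check_flashcomm2_layer_sharding(layer_sharding) -> list:
    """Accept an empty config or exactly ["o_proj"]; reject anything else,
    mirroring the ValueError raised in get_flashcomm2_config_and_validate."""
    layer_sharding = layer_sharding or []
    if layer_sharding and layer_sharding != ["o_proj"]:
        raise ValueError(
            "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding "
            f"configuration! Found invalid layer_sharding: {layer_sharding}")
    return layer_sharding


ok = check_flashcomm2_layer_sharding(["o_proj"])
empty = check_flashcomm2_layer_sharding(None)
try:
    check_flashcomm2_layer_sharding(["o_proj", "q_b_proj"])
    rejected = False
except ValueError:
    rejected = True
```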
 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
@@ -1160,4 +1163,20 @@ def singleton(cls):
         instances[cls] = cls(*args, **kwargs)
         return instances[cls]

     return get_instance
+
+
+@lru_cache(maxsize=1)
+def get_current_model_config():
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    return vllm_config.model_config
+
+
+# TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32; subsequent updates will introduce new interfaces. --zzhx1
+@lru_cache(maxsize=1)
+def enable_dsa_cp() -> bool:
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    return is_ds_v32 and enable_sp()
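The helpers added above lean on `@lru_cache(maxsize=1)` as a compute-once memo for process-wide config lookups: the first call does the work, later calls return the cached result. A minimal illustration (the counter exists only to show the body runs once):

```python
from functools import lru_cache

calls = []


@lru_cache(maxsize=1)
def cached_config_lookup() -> int:
    calls.append(1)  # record each real execution of the body
    return 42


first = cached_config_lookup()
second = cached_config_lookup()  # served from the cache; body not re-run
```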