From a5163c8c369e046c53e6127bbbcee23392cec069 Mon Sep 17 00:00:00 2001 From: AlvisGong Date: Sat, 6 Dec 2025 19:46:41 +0800 Subject: [PATCH] [Feat]enable sfa cp for dsv3.2 (#4702) ### What this PR does / why we need it? RFC: https://github.com/vllm-project/vllm/issues/30055 ### How was this patch tested? 1. enable flashcommon1 export VLLM_ASCEND_ENABLE_FLASHCOMM1=1 2. enable sfa-cp --additional-config '{ "enable_sfa_cp": true }' \ - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: AlvisGong Co-authored-by: clrs97 <524936896@qq.com> Co-authored-by: zzhx1 Co-authored-by: hwhaokun Co-authored-by: wangxiyuan --- vllm_ascend/attention/sfa_v1.py | 327 ++++++++++++++++++---- vllm_ascend/distributed/parallel_state.py | 27 +- vllm_ascend/ops/shared_weight_layer.py | 252 +++++++++++++++++ vllm_ascend/utils.py | 12 + 4 files changed, 564 insertions(+), 54 deletions(-) create mode 100644 vllm_ascend/ops/shared_weight_layer.py diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 24306a01..00be33f2 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -5,9 +5,9 @@ import torch import torch_npu from torch import nn from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import (LinearBase, +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear, UnquantizedLinearMethod) from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -17,10 +17,15 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, wait_for_kv_layer_from_connector) +from vllm_ascend.ops.shared_weight_layer import ( + is_hidden_layer, post_process_after_loading_for_shared_weight_series, + reach_layer_for_shared_weight_series, + register_layer_to_shared_weight_series) from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - is_enable_nz) + _round_up, dispose_layer, enable_sp, + is_enable_nz, replace_layer) from vllm_ascend.worker.npu_input_batch import InputBatch if TYPE_CHECKING: @@ -49,6 +54,20 @@ class AscendSFABackend(AttentionBackend): return AscendSFAImpl +@dataclass +class SfaCpContext: + num_tokens: int + num_tokens_pad: int + local_start: int + local_end: int + local_end_with_pad: int + pad_size: int + local_pad_size: int + slot_mapping_cp: torch.Tensor + actual_seq_lengths_query: torch.Tensor + actual_seq_lengths_key: torch.Tensor + + @dataclass class AscendSFAMetadata: """Metadata for MLACommon. 
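The SfaCpContext fields above are populated by the metadata builder in the following hunk. As a minimal illustration of the slicing they describe (with the rounding helper reimplemented locally to mirror _round_up from vllm_ascend/utils.py), each TP rank takes one contiguous, padded chunk of the flattened token sequence; the tensor-valued fields (slot_mapping_cp and the actual_seq_lengths_* tensors) are then the per-rank views built on top of this slice:

def sfa_cp_slice(num_actual_tokens: int, tp_size: int, tp_rank: int) -> dict:
    # Mirrors the builder: pad the token count up to a multiple of the TP
    # world size, then give every rank an equal contiguous chunk.
    round_up = lambda x, align: (x + align - 1) // align * align
    num_tokens_pad = round_up(num_actual_tokens, tp_size)
    num_tokens_per_device = num_tokens_pad // tp_size
    pad_size = num_tokens_pad - num_actual_tokens
    local_start = tp_rank * num_tokens_per_device
    local_end_with_pad = local_start + num_tokens_per_device
    local_end = min(local_end_with_pad, num_actual_tokens)
    local_pad_size = local_end_with_pad - local_end
    return dict(num_tokens=num_actual_tokens, num_tokens_pad=num_tokens_pad,
                pad_size=pad_size, local_start=local_start,
                local_end=local_end, local_end_with_pad=local_end_with_pad,
                local_pad_size=local_pad_size)

# Example: 10 tokens on a TP-4 group are padded to 12 (3 tokens per rank);
# rank 3 owns positions 9..11, of which 10 and 11 are padding.
assert sfa_cp_slice(10, 4, 3) == dict(
    num_tokens=10, num_tokens_pad=12, pad_size=2, local_start=9,
    local_end=10, local_end_with_pad=12, local_pad_size=2)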
@@ -79,6 +98,7 @@ class AscendSFAMetadata: attn_mask: torch.Tensor = None # chunked prefill by default if no attn_states passed attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + sfa_cp_context: Optional[SfaCpContext] = None M = TypeVar("M", bound=AscendSFAMetadata) @@ -122,6 +142,9 @@ class AscendSFAMetadataBuilder: self.cos_cache = None self.sin_cache = None + self.enable_sfa_cp = enable_sp() and \ + hasattr(self.model_config.hf_config, "index_topk") + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # No need to reorder for Ascend SFA @@ -171,6 +194,64 @@ class AscendSFAMetadataBuilder: sin = self.sin_cache[input_positions].unsqueeze( # type: ignore 1).unsqueeze(2) + sfa_cp_context = None + if self.enable_sfa_cp: + global_tp_size = get_tp_group().world_size + num_tokens = num_actual_tokens + num_tokens_pad = _round_up(num_actual_tokens, global_tp_size) + num_tokens_per_device = num_tokens_pad // global_tp_size + pad_size = num_tokens_pad - num_tokens + local_start = get_tp_group().rank_in_group * num_tokens_per_device + local_end_with_pad = local_start + num_tokens_per_device + local_end = min(local_end_with_pad, num_actual_tokens) + local_pad_size = local_end_with_pad - local_end + + if pad_size > 0: + cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size)) + sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size)) + slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size), + value=-1) + cos = cos[local_start:local_end_with_pad] + sin = sin[local_start:local_end_with_pad] + slot_mapping_cp = slot_mapping[local_start:local_end_with_pad] + + actual_seq_lengths_query = torch.empty_like(cum_query_lens) + actual_seq_lengths_key = torch.empty_like(seq_lens) + num_segs = cum_query_lens.shape[0] + last_token = 0 + cum = 0 + for i in range(0, num_segs): + global_start = last_token + global_end = cum_query_lens[i].item() + last_token = global_end + + local_start = max(global_start, local_start) + local_end = min(global_end, local_end_with_pad) + num_local_tokens = local_end - local_start + + if num_local_tokens > 0: + cum += num_local_tokens + actual_seq_lengths_query[i] = cum + + offset = global_end - local_end + actual_seq_lengths_key[i] = seq_lens[i].item() - offset + else: + actual_seq_lengths_query[i] = cum + actual_seq_lengths_key[i] = 0 + + sfa_cp_context = SfaCpContext( + num_tokens=num_tokens, + num_tokens_pad=num_tokens_pad, + local_start=local_start, + local_end=local_end, + local_end_with_pad=local_end_with_pad, + pad_size=pad_size, + local_pad_size=local_pad_size, + slot_mapping_cp=slot_mapping_cp, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + ) + return self.metadata_cls( # type: ignore has_prefill=has_prefill, num_input_tokens=common_attn_metadata.num_input_tokens, @@ -183,7 +264,8 @@ class AscendSFAMetadataBuilder: attn_state=common_attn_metadata.attn_state, block_tables=block_table, sin=sin, - cos=cos) + cos=cos, + sfa_cp_context=sfa_cp_context) def build_for_graph_capture( self, @@ -251,6 +333,7 @@ class AscendSFAImpl(MLAAttentionImpl): self.q_a_layernorm = kwargs.get('q_a_layernorm', None) self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group self.num_heads_per_rank = self.num_heads // self.tp_size self.q_b_proj = kwargs['q_b_proj'] @@ -258,8 +341,32 @@ class AscendSFAImpl(MLAAttentionImpl): self.enable_shared_expert_dp = 
ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - + self.vllm_config = get_current_vllm_config() assert self.indexer is not None, "Indexer is required for DSA." + + self.enable_sfa_cp = enable_sp() + self.local_num_heads = self.num_heads + + if self.enable_sfa_cp: + self.local_num_heads = self.num_heads * self.tp_size + + #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97 + self._replace_linear_class_for_sfa_cp() + from vllm_ascend.distributed.parallel_state import \ + get_shared_weight_group + if is_hidden_layer(self.vllm_config, self.q_proj): + register_layer_to_shared_weight_series( + series_name="q_proj", + group=get_shared_weight_group(), + layer=self.q_proj, + prefetch_step=1) + if is_hidden_layer(self.vllm_config, self.o_proj): + register_layer_to_shared_weight_series( + series_name="o_proj", + group=get_shared_weight_group(), + layer=self.o_proj, + prefetch_step=1) + # indexer param self.n_head: int = self.indexer.n_head # 64 self.head_dim: int = self.indexer.head_dim # 128 @@ -306,16 +413,16 @@ class AscendSFAImpl(MLAAttentionImpl): # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + self.kv_lora_rank, self.local_num_heads * + (self.qk_nope_head_dim + self.v_head_dim)), ( f"{kv_b_proj_weight.shape=}, " f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " + f"{self.local_num_heads=}, " f"{self.qk_nope_head_dim=}, " f"{self.v_head_dim=}") kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, - self.num_heads, + self.local_num_heads, self.qk_nope_head_dim + self.v_head_dim, ) @@ -336,29 +443,42 @@ class AscendSFAImpl(MLAAttentionImpl): # Waiting for BMM NZ support # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory + dispose_layer(self.kv_b_proj) + + if self.enable_sfa_cp: + if is_hidden_layer(self.vllm_config, self.q_proj): + post_process_after_loading_for_shared_weight_series( + self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + post_process_after_loading_for_shared_weight_series( + self.o_proj) def _v_up_proj(self, x): if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536: - x = x.view(-1, self.num_heads, self.kv_lora_rank) + x = x.view(-1, self.local_num_heads, self.kv_lora_rank) x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2]) - x = x.reshape(-1, self.num_heads * self.v_head_dim) + x = x.reshape(-1, self.local_num_heads * self.v_head_dim) else: # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + x = x.view(-1, self.local_num_heads, + self.kv_lora_rank).transpose(0, 1) # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) x = torch.bmm(x, self.W_UV) # # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + x = x.transpose(0, + 1).reshape(-1, + self.local_num_heads * self.v_head_dim) return x # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): q_nope, q_pe = self.q_proj(x)[0]\ - .view(-1, self.num_heads, self.qk_head_dim)\ + .view(-1, self.local_num_heads, self.qk_head_dim)\ 
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) @@ -375,6 +495,7 @@ class AscendSFAImpl(MLAAttentionImpl): sin: torch.Tensor, kv_cache: Tuple, slots: torch.Tensor, + slots_cp: Optional[torch.Tensor], ): B = kv_no_split.shape[0] N = self.num_kv_heads @@ -383,18 +504,44 @@ class AscendSFAImpl(MLAAttentionImpl): kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" - k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode=cache_mode, - ) - return k_pe, k_nope + + if self.enable_sfa_cp: + assert slots_cp is not None + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + slots_cp.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + is_output_kv=True, + ) + #TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97 + k_pe = get_tp_group().all_gather(k_pe, 0) + k_nope = get_tp_group().all_gather(k_nope, 0) + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_( + kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1), + k_nope.view(-1, k_nope.shape[-1])) + torch_npu.npu_scatter_nd_update_( + kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1), + k_pe.view(-1, k_pe.shape[-1])) + else: + torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) def rope_single( self, @@ -420,10 +567,20 @@ class AscendSFAImpl(MLAAttentionImpl): assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. 
+ if self.enable_sfa_cp: + from vllm.forward_context import get_forward_context + if not get_forward_context().in_profile_run: + if is_hidden_layer(self.vllm_config, self.q_proj): + reach_layer_for_shared_weight_series(self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + reach_layer_for_shared_weight_series(self.o_proj) + return output.fill_(0) has_prefill = attn_metadata.has_prefill num_actual_tokens = attn_metadata.num_actual_tokens hidden_states = hidden_states[:num_actual_tokens] + if self.enable_sfa_cp: + need_gather_q_kv = False # Inputs and outputs may be padded for CUDA graphs output_padded = output output = output[:num_actual_tokens] @@ -439,38 +596,61 @@ class AscendSFAImpl(MLAAttentionImpl): q_c = self.q_a_layernorm(q_c) # Process for Flash Comm V1 - q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - q_c.contiguous(), need_gather_q_kv) - kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - kv_no_split.contiguous(), need_gather_q_kv) + if need_gather_q_kv: + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + q_c.contiguous(), need_gather_q_kv) + kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + kv_no_split.contiguous(), need_gather_q_kv) if has_prefill: wait_for_kv_layer_from_connector(layer_name) + cos = attn_metadata.cos + sin = attn_metadata.sin slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens] - ql_nope, q_pe = \ - self._q_proj_and_k_up_proj(q_c) - q_pe = self.rope_single(q_pe, attn_metadata.cos, attn_metadata.sin) - k_pe, k_nope = self.exec_kv(kv_no_split, attn_metadata.cos, - attn_metadata.sin, kv_cache, slot_mapping) + slot_mapping_cp = None + actual_seq_lengths_query = attn_metadata.cum_query_lens + actual_seq_lengths_key = attn_metadata.seq_lens + if self.enable_sfa_cp: + assert attn_metadata.sfa_cp_context is not None + slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp + actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query + actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key - topk_indices = self.indexer_select(x=hidden_states, - qr=q_c, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - need_gather_q_kv=need_gather_q_kv) + self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, + slot_mapping_cp) + + if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None: + if is_hidden_layer(self.vllm_config, self.q_proj): + reach_layer_for_shared_weight_series(self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + reach_layer_for_shared_weight_series(self.o_proj) + + ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) + q_pe = self.rope_single(q_pe, cos, sin) + + topk_indices = self.indexer_select( + x=hidden_states, + qr=q_c, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + cos=cos, + sin=sin, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + need_gather_q_kv=need_gather_q_kv) attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( query=ql_nope, - key=k_nope, - value=k_nope, + key=kv_cache[0], + value=kv_cache[0], sparse_indices=topk_indices, scale_value=self.scale, sparse_block_size=1, block_table=attn_metadata.block_tables, - actual_seq_lengths_query=attn_metadata.cum_query_lens, - actual_seq_lengths_kv=attn_metadata.seq_lens, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_key, query_rope=q_pe, - key_rope=k_pe, + key_rope=kv_cache[1], layout_query="TND", layout_kv="PA_BSND", sparse_mode=3, @@ -489,11 +669,12 
@@ class AscendSFAImpl(MLAAttentionImpl): qr: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, + cos: torch.Tensor, + sin: torch.Tensor, + actual_seq_lengths_query: torch.Tensor, + actual_seq_lengths_key: torch.Tensor, need_gather_q_kv: bool = False, ): - cos = attn_metadata.cos - sin = attn_metadata.sin - # q process in new stream q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] @@ -539,6 +720,9 @@ class AscendSFAImpl(MLAAttentionImpl): k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + if self.enable_sfa_cp: + k = get_tp_group().all_gather(k, 0) + if kv_cache is not None: torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view( @@ -551,18 +735,55 @@ class AscendSFAImpl(MLAAttentionImpl): weights, need_gather_q_kv) block_table = attn_metadata.block_tables - seq_lens = attn_metadata.seq_lens - cum_query_lens = attn_metadata.cum_query_lens topk_indices = torch.ops._C_ascend.npu_lightning_indexer( query=q, key=kv_cache[2], weights=weights, - actual_seq_lengths_query=cum_query_lens, - actual_seq_lengths_key=seq_lens, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, block_table=block_table, layout_query="TND", layout_key="PA_BSND", sparse_count=2048, sparse_mode=3) return topk_indices + + def _replace_linear_class_for_sfa_cp(self): + + vllm_config = get_current_vllm_config() + # Dispose tensor from the original q_proj + dispose_layer(self.q_proj) + # Construct the new q_proj using ReplicatedLinear + new_q_proj = ReplicatedLinear(self.q_lora_rank, + self.local_num_heads * self.qk_head_dim, + bias=False, + quant_config=vllm_config.quant_config, + prefix=self.q_proj.prefix) + # Replace the q_proj with the new one + replace_layer(self.q_proj, new_q_proj) + + # Dispose tensor from the original kv_b_proj + dispose_layer(self.kv_b_proj) + # Construct the new kv_b_proj using ReplicatedLinear + new_kv_b_proj = ReplicatedLinear( + self.kv_lora_rank, + self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=vllm_config.quant_config, + prefix=self.kv_b_proj.prefix) + # Replace the kv_b_proj with the new one + replace_layer(self.kv_b_proj, new_kv_b_proj) + + # Dispose tensor from the original o_proj + dispose_layer(self.o_proj) + # Construct the new o_proj using ReplicatedLinear + config = vllm_config.model_config.hf_config + new_o_proj = ReplicatedLinear(config.num_attention_heads * + config.v_head_dim, + config.hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=self.o_proj.prefix) + # Replace the o_proj with the new one + replace_layer(self.o_proj, new_o_proj) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 96d403f9..b6979e58 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -9,7 +9,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import flashcomm2_enable +from vllm_ascend.utils import enable_sp, flashcomm2_enable # Currently, mc2 op need their own group coordinator. 
_MC2: Optional[GroupCoordinator] = None @@ -19,6 +19,7 @@ _LMTP: Optional[GroupCoordinator] = None _P_TP: Optional[GroupCoordinator] = None _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None +_SHARED_WEIGHT: Optional[GroupCoordinator] = None def get_mc2_group() -> GroupCoordinator: @@ -48,6 +49,13 @@ def get_flashcomm2_odp_group() -> GroupCoordinator: return _FLASHCOMM2_ODP +def get_shared_weight_group() -> GroupCoordinator: + assert _SHARED_WEIGHT is not None, ( + "output shared weight parallel group for flashcomm2 is not initialized" + ) + return _SHARED_WEIGHT + + def get_mlp_tp_group() -> GroupCoordinator: assert _MLP_TP is not None, ("mlp group is not initialized") return _MLP_TP @@ -226,6 +234,18 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): backend, group_name="flashcomm2_odp") + vllm_config = get_current_vllm_config() + # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97 + is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk") + if enable_sp() and is_ds_v32: + global _SHARED_WEIGHT + group_ranks = [list(range(torch.distributed.get_world_size()))] + _SHARED_WEIGHT = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="CP_shared_weight") + def get_mlp_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" @@ -274,3 +294,8 @@ def destroy_ascend_model_parallel(): ).flashcomm2_oproj_tensor_parallel_size != 1: _FLASHCOMM2_ODP.destroy() _FLASHCOMM2_ODP = None + + global _SHARED_WEIGHT + if _SHARED_WEIGHT: + _SHARED_WEIGHT.destroy() + _SHARED_WEIGHT = None diff --git a/vllm_ascend/ops/shared_weight_layer.py b/vllm_ascend/ops/shared_weight_layer.py new file mode 100644 index 00000000..99e92439 --- /dev/null +++ b/vllm_ascend/ops/shared_weight_layer.py @@ -0,0 +1,252 @@ +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +import torch.distributed as dist +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.models.utils import extract_layer_index + + +def dispose_tensor(x: torch.Tensor): + x.set_(torch.empty([], device=x.device, dtype=x.dtype)) + + +@dataclass +class LayerMetadata: + """Metadata for a layer. + """ + layer_idx: int # The index of the layer. + layer: LinearBase # The layer object. + post_method: Callable[[ + torch.nn.Module + ], None] # The `process_weights_after_loading` method from the quant method. + weight: torch.Tensor # The weight tensor. + window_idx: int # The index of the window. + + +@dataclass +class SharedWindowMetadata: + """Metadata for a shared window. + """ + weight: torch.Tensor # The weight tensor to be shared by layers. + data_layer_idx: int # The index of the layer this window's weight is equal to. + work: Optional[torch.distributed.Work] # The asynchronous broadcast work. + + +@dataclass +class SeriesMetadata: + """Metadata for a weight shared series. + """ + group: GroupCoordinator + start_layer: int + end_layer: int + num_layers: int + prefetch_step: int + dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor. + layers: list[LayerMetadata] + shared_windows: list[ + SharedWindowMetadata] # Shared windows for prefetching. 
The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. + window_offset: int # The index of the window for the next coming layer. + + def is_source(self, layer_idx) -> bool: + return layer_idx % self.group.world_size == self.group.rank_in_group + + def post_process_after_loading(self): + # This method only needs to be called once per series. + if self.shared_windows: + return + + self.layers.sort(key=lambda x: x.layer_idx) + self.num_layers = len(self.layers) + assert self.num_layers > 0, "No layers in the series" + assert self.prefetch_step >= 0 and self.prefetch_step <= max( + 0, self.num_layers - + 2), "prefetch_step must be in [0, num_layers - 2]" + self.start_layer = self.layers[0].layer_idx + self.end_layer = self.layers[-1].layer_idx + 1 + + for layer_idx in range(self.start_layer, self.end_layer): + layer = self.layers[layer_idx - self.start_layer] + assert layer.layer_idx == layer_idx, "layer_idx must be consecutive" + is_source = self.is_source(layer_idx) + # If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight. + if not is_source: + layer.weight.set_(torch.empty_like(self.dummy_weight)) + # Broadcast to get the true weight. + dist.broadcast(layer.weight, + src=self.group.ranks[layer_idx % + self.group.world_size], + group=self.group.device_group) + # Call `process_weights_after_loading` from the quant method. + layer.post_method(layer.layer) + step = layer_idx - self.start_layer + if step < self.prefetch_step: + # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights. + self.shared_windows.append( + SharedWindowMetadata( + weight=layer.weight.clone().detach(), + data_layer_idx=layer_idx, + work=None, + )) + layer.window_idx = step + # When the layer not intended to be stored in this device, link to the corresponding window's tensor. + if not is_source: + layer.weight.set_(self.shared_windows[-1].weight) + else: + # Build one more window for prefetch. The weight is useless, so just keep the shape. + if step == self.prefetch_step: + self.shared_windows.append( + SharedWindowMetadata( + weight=torch.empty_like(layer.weight), + data_layer_idx=-1, + work=None, + )) + # When the layer not intended to be stored in this device, dispose the tensor. + if not is_source: + dispose_tensor(layer.weight) + # Dispose the dummy tensor since it's no longer needed. + dispose_tensor(self.dummy_weight) + + def reach_layer(self, layer_idx: int): + # The index of the layer to be prefetched. + next_layer_idx = (layer_idx + self.prefetch_step + ) % self.num_layers + self.start_layer + next_layer = self.layers[next_layer_idx - self.start_layer] + # The index of the window to store the weight for the coming layer. + next_layer.window_idx = self.window_offset + window = self.shared_windows[next_layer.window_idx] + # When the layer not intended to be stored in this device, link to the corresponding window's tensor. + if not self.is_source(next_layer_idx): + next_layer.weight.set_(window.weight) + # Update `window_offset` by rolling one step. + self.window_offset = (self.window_offset + 1) % (self.prefetch_step + + 1) + assert window.data_layer_idx != next_layer_idx + window.data_layer_idx = next_layer_idx + # Start asynchronous broadcast work. 
+ window.work = dist.broadcast( + next_layer.weight, + src=self.group.ranks[next_layer_idx % self.group.world_size], + group=self.group.device_group, + async_op=True) + + def wait_weight(self, layer_idx: int): + # Find the asynchronous broadcast work and wait for it. + assert self.shared_windows + window = self.shared_windows[self.layers[layer_idx - + self.start_layer].window_idx] + # Make sure the data in the corresponding shared window is for the current layer. + assert window.data_layer_idx == layer_idx + if window.work is not None: + window.work.wait() + window.work = None + + +@dataclass +class LayerExternalMetadata: + """External metadata for a layer. + """ + series: SeriesMetadata + layer_idx: int + + +_series_dict: dict[str, SeriesMetadata] = {} + +_layer_external_dict: dict[int, LayerExternalMetadata] = {} + + +def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, + layer_idx: int) -> Callable: + + def wrapped_forward(*args, **kwargs): + # Wait for the weight. + series.wait_weight(layer_idx) + return forward(*args, **kwargs) + + return wrapped_forward + + +""" +Register linear layers into a shared storage series. + +In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices. + +After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization. + +During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. + +Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula: +- start_layer is the index of the first layer in the series (inclusive). +- end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer). +- total_layers = end_layer - start_layer +- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer + +To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series. + +Arguments: + series_name: This name identifies which series this layer belongs to. + group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series. + layer: The linear layer object to register. + prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases. 
+""" + + +def register_layer_to_shared_weight_series( + series_name: str, + group: GroupCoordinator, + layer: LinearBase, + prefetch_step: int = 1, +): + global _series_dict + if series_name not in _series_dict: + _series_dict[series_name] = SeriesMetadata( + group=group, + start_layer=0, + end_layer=0, + num_layers=0, + prefetch_step=prefetch_step, + dummy_weight=torch.empty_like(layer.weight), + layers=[], + shared_windows=[], + window_offset=prefetch_step, + ) + series = _series_dict[series_name] + assert layer.quant_method is not None + layer_idx = extract_layer_index(layer.prefix) + series.layers.append( + LayerMetadata( + layer_idx=layer_idx, + layer=layer, + post_method=layer.quant_method.process_weights_after_loading, + weight=layer.weight, + window_idx=-1, + )) + # Discard the original `process_weights_after_loading` method such that it won't be called by others. + layer.quant_method.process_weights_after_loading = lambda layer: None + # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading. + if not series.is_source(layer_idx): + dispose_tensor(layer.weight) + layer.weight.weight_loader = lambda *args, **kwargs: None + layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx) + global _layer_external_dict + _layer_external_dict[id(layer)] = LayerExternalMetadata( + series=series, + layer_idx=layer_idx, + ) + + +def post_process_after_loading_for_shared_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.post_process_after_loading() + + +def reach_layer_for_shared_weight_series(layer: LinearBase): + ext = _layer_external_dict[id(layer)] + ext.series.reach_layer(ext.layer_idx) + + +def is_hidden_layer(vllm_config, layer: LinearBase) -> bool: + num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + layer_idx = extract_layer_index(layer.prefix) + return layer_idx < num_hidden_layers \ No newline at end of file diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 85031bf6..16c0f68d 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1067,3 +1067,15 @@ def refresh_block_size(vllm_config): "Block size is set to 128 if prefix cache or chunked prefill is enabled." ) cache_config.block_size = 128 + + +def dispose_layer(layer: Any): + for attr_name in dir(layer): + attr_value = getattr(layer, attr_name) + if isinstance(attr_value, torch.Tensor): + dispose_tensor(attr_value) + + +def replace_layer(original_layer: Any, new_layer: Any): + original_layer.__class__ = new_layer.__class__ + original_layer.__dict__ = new_layer.__dict__