From cd1ffbb6cd88a3f265027424bd3cca74d1efb1ea Mon Sep 17 00:00:00 2001 From: clrs97 <33470839+clrs97@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:16:41 +0800 Subject: [PATCH] [1/N][Feat] Cut down memory usage for o_proj in DeepSeek (#2931) ### What this PR does / why we need it? To cut down the memory usage of large weight matrices, we often rely on various linear operations: - `ReplicatedLinear`: Stores the entire matrix, consuming excessive memory. - `RowParallelLinear`: Requires an `all_reduce` to merge the partial results, introducing additional communication overhead and potential accuracy loss. Each token is handled across multiple devices rather than a single device, which is undesirable in the SP scenario. - ... Furthermore, in multi-way Data Parallelism (DP) configurations, layers typically store redundant weight copies. This PR introduces a shared-weight plugin for layers inheriting from `LinearBase`. It offers the following advantages: - It evenly distributes a set of layers with identical structures across devices. Each layer retains its complete weights, eliminating redundant memory usage. - It supports asynchronous broadcasting to prefetch weights for upcoming layers. - It preserves the custom `process_weights_after_loading()` method, making it possible to keep the NZ format. - It is compatible with any linear class that inherits from `LinearBase`, thereby preserving all the features of the original linear implementation. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 
from dataclasses import dataclass
from typing import Callable, Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase


def dispose_tensor(x: torch.Tensor):
    """Release the storage behind `x` by re-pointing it at an empty 0-d tensor."""
    x.set_(torch.empty([], device=x.device, dtype=x.dtype))


@dataclass
class LayerMetadata:
    """Metadata for a layer registered in a weight-shared series."""
    # The layer object (None for the placeholder created at series creation).
    layer: Optional[LinearBase]
    # The `process_weights_after_loading` method taken from the quant method.
    post_method: Callable[[torch.nn.Module], None]
    # The weight tensor of the layer.
    weight: torch.Tensor
    # Index of the shared window currently backing this layer's weight.
    window_idx: int


@dataclass
class SharedWindowMetadata:
    """Metadata for a shared window (one slot of the prefetch ring buffer)."""
    # The weight tensor to be shared by layers.
    weight: torch.Tensor
    # The index of the layer whose weight this window currently holds.
    data_layer_idx: int
    # The in-flight asynchronous broadcast work, if any.
    work: Optional[dist.Work]


@dataclass
class SeriesMetadata:
    """Metadata for a weight-shared series of isomorphic linear layers."""
    # Group used to broadcast weights between the participating devices.
    group: GroupCoordinator
    # The series covers layers with indices in [start_layer, end_layer).
    start_layer: int
    end_layer: int
    num_layers: int
    # Number of layers ahead whose weights are prefetched asynchronously.
    prefetch_step: int
    # Dummy weight used as a shape/dtype template for layers whose real
    # weight lives on another device. All layers in the series share this
    # single tensor; it is disposed once post-processing is done.
    dummy_weight: torch.Tensor
    layers: list[LayerMetadata]
    # Ring buffer of (prefetch_step + 1) windows: only the weights for the
    # current layer and the next prefetch_step layers need to be live.
    shared_windows: list[SharedWindowMetadata]
    # The index of the window to be used for the next upcoming layer.
    window_offset: int

    def is_source(self, layer_idx) -> bool:
        """Whether this device owns the (full) weight of layer `layer_idx`.

        The weight of layer i is stored on device (i % world_size).
        """
        return layer_idx % self.group.world_size == self.group.rank_in_group

    def post_process_after_loading(self):
        """One-time initialization after all weights have been loaded.

        Materializes every layer's weight via a broadcast from the owning
        device, runs the original `process_weights_after_loading`, then
        builds the shared-window ring buffer and releases every tensor this
        device does not own.
        """
        # This method only needs to run once per series; the presence of
        # shared windows marks it as already done.
        if self.shared_windows:
            return
        for layer_idx in range(self.start_layer, self.end_layer):
            layer = self.layers[layer_idx - self.start_layer]
            is_source = self.is_source(layer_idx)
            # Non-source layers get a temporary private buffer so that the
            # post-method call below cannot affect other layers that also
            # rely on the shared dummy weight.
            if not is_source:
                layer.weight.set_(torch.empty_like(self.dummy_weight))
            # Broadcast from the owning device to obtain the true weight.
            dist.broadcast(layer.weight,
                           src=self.group.ranks[layer_idx %
                                                self.group.world_size],
                           group=self.group.device_group)
            assert layer.layer is not None
            # Run the original `process_weights_after_loading` from the quant
            # method (e.g. to keep the NZ format) on the materialized weight.
            layer.post_method(layer.layer)
            step = layer_idx - self.start_layer
            if step < self.prefetch_step:
                # The first `prefetch_step` layers are needed right at the
                # start of `forward()`, so their windows are pre-filled with
                # a copy of the processed weight.
                self.shared_windows.append(
                    SharedWindowMetadata(
                        weight=layer.weight.clone().detach(),
                        data_layer_idx=layer_idx,
                        work=None,
                    ))
                layer.window_idx = step
                # A layer not stored on this device aliases the window tensor.
                if not is_source:
                    layer.weight.set_(self.shared_windows[-1].weight)
            else:
                # One extra window for prefetching; its initial content is
                # irrelevant, only the shape/dtype matter.
                if step == self.prefetch_step:
                    self.shared_windows.append(
                        SharedWindowMetadata(
                            weight=torch.empty_like(layer.weight),
                            data_layer_idx=-1,
                            work=None,
                        ))
                # Weights not owned by this device are released again.
                if not is_source:
                    dispose_tensor(layer.weight)

        dispose_tensor(self.dummy_weight)

    def reach_layer(self, layer_idx: int):
        """Trigger the asynchronous prefetch when `layer_idx` is reached.

        Starts an async broadcast of the weight of the layer that lies
        `prefetch_step` positions after `layer_idx` (circularly) into the
        next free shared window.
        """
        # Index of the layer to be prefetched. NOTE: the offset must be taken
        # relative to `start_layer`, otherwise the result is wrong whenever
        # `start_layer` is not a multiple of `num_layers` (the original
        # `(layer_idx + prefetch_step) % num_layers + start_layer` only
        # coincides with this in that special case).
        next_layer_idx = (layer_idx - self.start_layer + self.prefetch_step
                          ) % self.num_layers + self.start_layer
        next_layer = self.layers[next_layer_idx - self.start_layer]
        # The window that will receive the weight for the coming layer.
        next_layer.window_idx = self.window_offset
        window = self.shared_windows[next_layer.window_idx]
        # A layer not stored on this device aliases the window's tensor; the
        # source device broadcasts straight out of its own resident weight.
        if not self.is_source(next_layer_idx):
            next_layer.weight.set_(window.weight)
        # Roll the ring buffer one step forward.
        self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
                                                         1)
        # The window being recycled must not already hold this layer's data.
        assert window.data_layer_idx != next_layer_idx
        window.data_layer_idx = next_layer_idx
        # Kick off the asynchronous broadcast; `wait_weight` collects it.
        window.work = dist.broadcast(
            next_layer.weight,
            src=self.group.ranks[next_layer_idx % self.group.world_size],
            group=self.group.device_group,
            async_op=True)

    def wait_weight(self, layer_idx: int):
        """Block until the weight for layer `layer_idx` has arrived."""
        assert self.shared_windows
        window = self.shared_windows[self.layers[layer_idx -
                                                 self.start_layer].window_idx]
        # The shared window must currently hold the data for this layer.
        assert window.data_layer_idx == layer_idx
        if window.work is not None:
            window.work.wait()
            window.work = None


@dataclass
class LayerExternalMetadata:
    """Back-reference from a registered layer to its series."""
    # The series the layer belongs to.
    series: SeriesMetadata
    # The absolute index of the layer within the series.
    layer_idx: int


# Series registry, keyed by series name.
_series_dict: dict[str, SeriesMetadata] = {}

# Per-layer metadata, keyed by id(layer).
_layer_external_dict: dict[int, LayerExternalMetadata] = {}


def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
                            layer_idx: int) -> Callable:
    """Wrap `forward` so the layer's weight is awaited before running it."""

    def wrapped_forward(*args, **kwargs):
        # Make sure the (possibly prefetched) weight has fully arrived.
        series.wait_weight(layer_idx)
        return forward(*args, **kwargs)

    return wrapped_forward


def register_layer_to_shared_weight_series(
    series_name: str,
    group: GroupCoordinator,
    start_layer: int,
    end_layer: int,
    layer_idx: int,
    layer: LinearBase,
    prefetch_step: int = 1,
):
    """Register a linear layer into a shared-weight series.

    In a parallel group, each device stores a distinct, non-overlapping
    subset of layers from the series. All layers in a series must have the
    same structure (be isomorphic). The weight matrix of the i-th layer is
    stored on device (i % n), where n is the number of devices.

    After loading the model, you must call
    `post_process_after_loading_for_shared_weight_series(layer)` on any layer
    of this series to complete the initialization.

    During execution, each time a new layer is reached, you must call
    `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch
    weights. Each such call triggers an asynchronous prefetch for the weights
    of the `prefetch_step`-th subsequent layer within the series.

    Note: the layers are managed as a circular buffer. The index of the
    layer to prefetch is determined by:
    - total_layers = end_layer - start_layer
    - prefetch_layer_idx = (layer_idx - start_layer + prefetch_step)
                           % total_layers + start_layer

    To hold the weights of the current layer and the prefetched ones, a pool
    of (prefetch_step + 1) shared tensor buffers is created for the series.

    Args:
        series_name: Identifies which series this layer belongs to.
        group: The group coordinator handling asynchronous communications.
            It is recommended to create a new group coordinator per series.
        start_layer: The index of the first layer in the series (inclusive).
        end_layer: The index of the last layer in the series (exclusive);
            the series covers [start_layer, end_layer).
        layer_idx: The index of the current layer.
        layer: The linear layer object to register.
        prefetch_step: A non-negative integer that manages asynchronous
            weight prefetching. Setting it to 0 or 1 covers most cases.
    """
    if series_name not in _series_dict:
        num_layers = end_layer - start_layer
        assert num_layers > 0
        # Need at least (prefetch_step + 1) windows strictly fewer than the
        # number of layers, so a window is never recycled while still in use.
        assert 0 <= prefetch_step <= num_layers - 2
        _series_dict[series_name] = SeriesMetadata(
            group=group,
            start_layer=start_layer,
            end_layer=end_layer,
            num_layers=num_layers,
            prefetch_step=prefetch_step,
            dummy_weight=torch.empty_like(layer.weight),
            layers=[
                LayerMetadata(
                    layer=None,
                    post_method=lambda layer: None,
                    weight=torch.empty([]),
                    window_idx=-1,
                ) for _ in range(num_layers)
            ],
            shared_windows=[],
            window_offset=prefetch_step,
        )
    series = _series_dict[series_name]
    assert layer.quant_method is not None
    series.layers[layer_idx - start_layer] = LayerMetadata(
        layer=layer,
        post_method=layer.quant_method.process_weights_after_loading,
        weight=layer.weight,
        window_idx=-1,
    )
    # Neutralize the original `process_weights_after_loading` so it won't be
    # invoked again by the generic loading path; the series calls the saved
    # copy itself at the right time.
    layer.quant_method.process_weights_after_loading = lambda layer: None
    # When the layer is not stored on this device, release the tensor and
    # skip weight loading entirely.
    if not series.is_source(layer_idx):
        dispose_tensor(layer.weight)
        layer.weight.weight_loader = lambda *args, **kwargs: None
    layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
    _layer_external_dict[id(layer)] = LayerExternalMetadata(
        series=series,
        layer_idx=layer_idx,
    )


def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
    """Finish series initialization; call once on any layer of the series."""
    ext = _layer_external_dict[id(layer)]
    ext.series.post_process_after_loading()


def reach_layer_for_shared_weight_series(layer: LinearBase):
    """Notify the series that `layer` has been reached, starting a prefetch."""
    ext = _layer_external_dict[id(layer)]
    ext.series.reach_layer(ext.layer_idx)