[refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable `VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`.
- Introduce `layer_sharding` as a configurable feature in `additional_config`.
- Revise the term "shared weight" to "shard weight".

Configuration: the feature is opt-in via the `additional_config` argument:
```
--additional-config '{
"layer_sharding": ["o_proj", "q_b_proj"]
}'
```
This is orthogonal to standard tensor parallelism and weight-replication
strategies, and is treated as a separate, explicit feature. It can be used
in any scenario, combined with the flashcomm2 feature
(https://github.com/vllm-project/vllm-ascend/pull/3232) or the ShardedCP
feature (#4702), to achieve significant performance gains.
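
For offline inference, the same opt-in can be expressed in Python. This is a minimal sketch only: it assumes the `LLM` entry point forwards `additional_config` to vllm-ascend, and the model path is a placeholder.

```
from vllm import LLM, SamplingParams

# Hypothetical offline-inference equivalent of the --additional-config flag above:
# shard the o_proj and q_b_proj weights across the shard-weight group.
llm = LLM(
    model="your-model-path",  # placeholder
    additional_config={
        "layer_sharding": ["o_proj", "q_b_proj"],
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```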
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
```
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import Callable, Optional
+from functools import lru_cache
+from typing import Callable, List, Optional

 import torch
 import torch.distributed as dist
@@ -7,6 +8,8 @@ from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.models.utils import extract_layer_index

+from vllm_ascend.distributed.parallel_state import get_shard_weight_group


 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty([], device=x.device, dtype=x.dtype))
@@ -26,17 +29,17 @@ class LayerMetadata:


 @dataclass
-class SharedWindowMetadata:
-    """Metadata for a shared window.
+class ShardWindowMetadata:
+    """Metadata for a shard window.
     """
-    weight: torch.Tensor  # The weight tensor to be shared by layers.
+    weight: torch.Tensor  # The weight tensor to be shard by layers.
     data_layer_idx: int  # The index of the layer this window's weight is equal to.
     work: Optional[torch.distributed.Work]  # The asynchronous broadcast work.


 @dataclass
 class SeriesMetadata:
-    """Metadata for a weight shared series.
+    """Metadata for a weight shard series.
     """
     group: GroupCoordinator
     start_layer: int
@@ -45,8 +48,8 @@ class SeriesMetadata:
     prefetch_step: int
     dummy_weight: torch.Tensor  # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
     layers: list[LayerMetadata]
-    shared_windows: list[
-        SharedWindowMetadata]  # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
+    shard_windows: list[
+        ShardWindowMetadata]  # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
     window_offset: int  # The index of the window for the next coming layer.

     def is_source(self, layer_idx) -> bool:
@@ -54,7 +57,7 @@ class SeriesMetadata:

     def post_process_after_loading(self):
         # This method only needs to be called once per series.
-        if self.shared_windows:
+        if self.shard_windows:
             return

         self.layers.sort(key=lambda x: x.layer_idx)
@@ -83,8 +86,8 @@ class SeriesMetadata:
             step = layer_idx - self.start_layer
             if step < self.prefetch_step:
                 # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
-                self.shared_windows.append(
-                    SharedWindowMetadata(
+                self.shard_windows.append(
+                    ShardWindowMetadata(
                         weight=layer.weight.clone().detach(),
                         data_layer_idx=layer_idx,
                         work=None,
@@ -92,12 +95,12 @@ class SeriesMetadata:
                 layer.window_idx = step
                 # When the layer not intended to be stored in this device, link to the corresponding window's tensor.
                 if not is_source:
-                    layer.weight.set_(self.shared_windows[-1].weight)
+                    layer.weight.set_(self.shard_windows[-1].weight)
             else:
                 # Build one more window for prefetch. The weight is useless, so just keep the shape.
                 if step == self.prefetch_step:
-                    self.shared_windows.append(
-                        SharedWindowMetadata(
+                    self.shard_windows.append(
+                        ShardWindowMetadata(
                             weight=torch.empty_like(layer.weight),
                             data_layer_idx=-1,
                             work=None,
@@ -115,7 +118,7 @@ class SeriesMetadata:
         next_layer = self.layers[next_layer_idx - self.start_layer]
         # The index of the window to store the weight for the coming layer.
         next_layer.window_idx = self.window_offset
-        window = self.shared_windows[next_layer.window_idx]
+        window = self.shard_windows[next_layer.window_idx]
         # When the layer not intended to be stored in this device, link to the corresponding window's tensor.
         if not self.is_source(next_layer_idx):
             next_layer.weight.set_(window.weight)
@@ -133,10 +136,10 @@ class SeriesMetadata:

     def wait_weight(self, layer_idx: int):
         # Find the asynchronous broadcast work and wait for it.
-        assert self.shared_windows
-        window = self.shared_windows[self.layers[layer_idx -
-                                                 self.start_layer].window_idx]
-        # Make sure the data in the corresponding shared window is for the current layer.
+        assert self.shard_windows
+        window = self.shard_windows[self.layers[layer_idx -
+                                                self.start_layer].window_idx]
+        # Make sure the data in the corresponding shard window is for the current layer.
         assert window.data_layer_idx == layer_idx
         if window.work is not None:
             window.work.wait()
```
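
The window metadata above carries a `torch.distributed.Work` handle so that `wait_weight()` blocks only when the weights are actually needed. Below is a minimal standalone sketch of that issue-then-wait pattern; the helper names, ranks, and shapes are illustrative and not part of this diff.

```
import torch
import torch.distributed as dist


def start_prefetch(window_weight: torch.Tensor, src_rank: int,
                   group: dist.ProcessGroup) -> dist.Work:
    # Kick off a non-blocking broadcast: the source rank sends its stored
    # weight into every other rank's window buffer.
    return dist.broadcast(window_weight, src=src_rank, group=group, async_op=True)


def wait_and_use(window_weight: torch.Tensor, work) -> None:
    # Mirrors SeriesMetadata.wait_weight(): block right before the weight is used.
    if work is not None:
        work.wait()
    # ... the layer's matmul can now safely read window_weight ...
```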
```
@@ -168,13 +171,13 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,


 """
-Register linear layers into a shared storage series.
+Register linear layers into a shard storage series.

 In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.

-After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
+After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization.

-During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
+During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.

 Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
 - start_layer is the index of the first layer in the series (inclusive).
@@ -182,7 +185,7 @@ Note: The layers are managed as a circular buffer. The index of the layer to pre
 - total_layers = end_layer - start_layer
 - prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer

-To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
+To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series.

 Arguments:
     series_name: This name identifies which series this layer belongs to.
```
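
A quick worked example of the circular-buffer formula from the docstring above (the layer range and `prefetch_step` value are illustrative):

```
# Illustrative series: layers 0..59, prefetching one layer ahead.
start_layer, end_layer, prefetch_step = 0, 60, 1
total_layers = end_layer - start_layer  # 60


def prefetch_layer_idx(layer_idx: int) -> int:
    # Formula quoted from the docstring.
    return (layer_idx + prefetch_step) % total_layers + start_layer


assert prefetch_layer_idx(0) == 1    # reaching layer 0 prefetches layer 1
assert prefetch_layer_idx(58) == 59  # normal case in the middle of the series
assert prefetch_layer_idx(59) == 0   # the last layer wraps around to the first
```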
```
@@ -192,7 +195,7 @@ Arguments:
 """


-def register_layer_to_shared_weight_series(
+def register_layer_to_shard_weight_series(
         series_name: str,
         group: GroupCoordinator,
         layer: LinearBase,
@@ -208,7 +211,7 @@ def register_layer_to_shared_weight_series(
             prefetch_step=prefetch_step,
             dummy_weight=torch.empty_like(layer.weight),
             layers=[],
-            shared_windows=[],
+            shard_windows=[],
             window_offset=prefetch_step,
         )
     series = _series_dict[series_name]
@@ -236,17 +239,42 @@ def register_layer_to_shared_weight_series(
     )


-def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
+def post_process_after_loading_for_shard_weight_series(layer: LinearBase):
     ext = _layer_external_dict[id(layer)]
     ext.series.post_process_after_loading()


-def reach_layer_for_shared_weight_series(layer: LinearBase):
+def reach_layer_for_shard_weight_series(layer: LinearBase):
     ext = _layer_external_dict[id(layer)]
     ext.series.reach_layer(ext.layer_idx)


-def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
-    num_hidden_layers = vllm_config.model_config.hf_text_config.num_hidden_layers
+def wait_layer_for_shard_weight_series(layer: LinearBase):
+    ext = _layer_external_dict[id(layer)]
+    ext.series.wait_weight(ext.layer_idx)
+
+
+@lru_cache(maxsize=1)
+def get_current_model_num_hidden_layers() -> int:
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    return vllm_config.model_config.get_total_num_hidden_layers()
+
+
+def is_hidden_layer(layer: LinearBase) -> bool:
+    num_hidden_layers = get_current_model_num_hidden_layers()
     layer_idx = extract_layer_index(layer.prefix)
     return layer_idx < num_hidden_layers
+
+
+def register_all_layers_to_shard_weight_series(
+        layer_sharding: List[LinearBase], ):
+    for curr_layer in (layer_sharding or []):
+        if is_hidden_layer(curr_layer):
+            layer_name = curr_layer.prefix.split('.')[-1]
+            register_layer_to_shard_weight_series(
+                series_name=layer_name,
+                group=get_shard_weight_group(),
+                layer=curr_layer,
+                prefetch_step=1,
+            )
```
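
Putting the renamed entry points together, the intended call sequence looks roughly like the sketch below. The loop structure and the `layers_to_shard` list are illustrative; only the function names and the series-name convention (last component of the layer prefix, e.g. `o_proj`) come from this diff.

```
# 1. Register every eligible hidden layer once, after model construction.
#    Series names are derived from the last prefix component (e.g. "o_proj").
register_all_layers_to_shard_weight_series(layers_to_shard)

# 2. After the checkpoint weights are loaded, finish initialization.
#    post_process_after_loading() is idempotent per series, so calling it for
#    every registered layer is safe.
for layer in layers_to_shard:
    post_process_after_loading_for_shard_weight_series(layer)

# 3. In each forward pass, prefetch ahead and wait just before use.
for layer in layers_to_shard:
    reach_layer_for_shard_weight_series(layer)  # async prefetch, prefetch_step layers ahead
    wait_layer_for_shard_weight_series(layer)   # block until this layer's weights are ready
    # ... run the layer with its now-valid weight ...
```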