### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/layer_shard_linear.py` |
| `vllm_ascend/ops/linear.py` |
| `vllm_ascend/ops/linear_op.py` |
| `vllm_ascend/worker/worker.py` |
| `vllm_ascend/patch/worker/patch_bert.py` |
| `vllm_ascend/patch/worker/patch_deepseek.py` |
| `vllm_ascend/patch/worker/patch_distributed.py` |
| `vllm_ascend/patch/worker/patch_module.py` |
| `vllm_ascend/patch/worker/patch_multimodal_merge.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` |
| `vllm_ascend/patch/worker/patch_rejection_sampler.py` |
| `vllm_ascend/patch/worker/patch_rope.py` |
| `vllm_ascend/patch/worker/patch_triton.py` |
| `vllm_ascend/patch/worker/patch_unquantized_gemm.py` |
| `vllm_ascend/patch/worker/patch_v2_egale.py` |
| `vllm_ascend/worker/npu_input_batch.py` |
| `vllm_ascend/worker/v2/aclgraph_utils.py` |
| `vllm_ascend/worker/v2/attn_utils.py` |
| `vllm_ascend/worker/v2/model_runner.py` |
| `vllm_ascend/worker/v2/sample/gumbel.py` |
| `vllm_ascend/worker/v2/sample/penalties.py` |
| `vllm_ascend/worker/v2/sample/sampler.py` |
| `vllm_ascend/worker/v2/spec_decode/__init__.py` |
| `vllm_ascend/worker/v2/spec_decode/eagle.py` |
| `vllm_ascend/worker/v2/states.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
```diff
@@ -1,6 +1,6 @@
+from collections.abc import Callable
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -17,39 +17,38 @@ def dispose_tensor(x: torch.Tensor):
 
 @dataclass
 class LayerMetadata:
-    """Metadata for a layer.
-    """
+    """Metadata for a layer."""
+
     layer_idx: int  # The index of the layer.
     layer: LinearBase  # The layer object.
-    post_method: Callable[[
-        torch.nn.Module
-    ], None]  # The `process_weights_after_loading` method from the quant method.
+    post_method: Callable[[torch.nn.Module], None]  # The `process_weights_after_loading` method from the quant method.
     weight: torch.Tensor  # The weight tensor.
     window_idx: int  # The index of the window.
 
 
 @dataclass
 class ShardWindowMetadata:
-    """Metadata for a shard window.
-    """
+    """Metadata for a shard window."""
+
     weight: torch.Tensor  # The weight tensor to be shard by layers.
     data_layer_idx: int  # The index of the layer this window's weight is equal to.
-    work: Optional[torch.distributed.Work]  # The asynchronous broadcast work.
+    work: torch.distributed.Work | None  # The asynchronous broadcast work.
 
 
 @dataclass
 class SeriesMetadata:
-    """Metadata for a weight shard series.
-    """
+    """Metadata for a weight shard series."""
+
     group: GroupCoordinator
     start_layer: int
     end_layer: int
     num_layers: int
    prefetch_step: int
-    dummy_weight: torch.Tensor  # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
+    dummy_weight: torch.Tensor  # Dummy weight to replace the loaded weight matrix.
+    # All the layers in the series share the same dummy weight tensor.
     layers: list[LayerMetadata]
-    shard_windows: list[
-        ShardWindowMetadata]  # Shard windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
+    shard_windows: list[ShardWindowMetadata]  # Shard windows for prefetching. The window size is (`prefetch_step` + 1),
+    # as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
     window_offset: int  # The index of the window for the next coming layer.
 
     def is_source(self, layer_idx) -> bool:
@@ -63,9 +62,9 @@ class SeriesMetadata:
         self.layers.sort(key=lambda x: x.layer_idx)
         self.num_layers = len(self.layers)
         assert self.num_layers > 0, "No layers in the series"
-        assert self.prefetch_step >= 0 and self.prefetch_step <= max(
-            0, self.num_layers -
-            2), "prefetch_step must be in [0, num_layers - 2]"
+        assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), (
+            "prefetch_step must be in [0, num_layers - 2]"
+        )
         self.start_layer = self.layers[0].layer_idx
         self.end_layer = self.layers[-1].layer_idx + 1
 
@@ -73,25 +72,27 @@ class SeriesMetadata:
         layer = self.layers[layer_idx - self.start_layer]
         assert layer.layer_idx == layer_idx, "layer_idx must be consecutive"
         is_source = self.is_source(layer_idx)
-        # If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight.
+        # If the weight uses dummy weight, make a copy temporary such that the post method call
+        # won't affect other layers which also uses dummy weight.
         if not is_source:
             layer.weight.set_(torch.empty_like(self.dummy_weight))
         # Broadcast to get the true weight.
-        dist.broadcast(layer.weight,
-                       src=self.group.ranks[layer_idx %
-                                            self.group.world_size],
-                       group=self.group.device_group)
+        dist.broadcast(
+            layer.weight, src=self.group.ranks[layer_idx % self.group.world_size], group=self.group.device_group
+        )
         # Call `process_weights_after_loading` from the quant method.
         layer.post_method(layer.layer)
         step = layer_idx - self.start_layer
         if step < self.prefetch_step:
-            # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
+            # Build the windows for the first `prefetch_step` layers. The weights can be used
+            # for the first `prefetch_step` layers in `forward()`, so also clone the weights.
             self.shard_windows.append(
                 ShardWindowMetadata(
                     weight=layer.weight.clone().detach(),
                     data_layer_idx=layer_idx,
                     work=None,
-                ))
+                )
+            )
             layer.window_idx = step
             # When the layer not intended to be stored in this device, link to the corresponding window's tensor.
             if not is_source:
@@ -104,7 +105,8 @@ class SeriesMetadata:
                     weight=torch.empty_like(layer.weight),
                     data_layer_idx=-1,
                     work=None,
-                ))
+                )
+            )
             # When the layer not intended to be stored in this device, dispose the tensor.
             if not is_source:
                 dispose_tensor(layer.weight)
@@ -113,8 +115,7 @@ class SeriesMetadata:
 
     def reach_layer(self, layer_idx: int):
         # The index of the layer to be prefetched.
-        next_layer_idx = (layer_idx + self.prefetch_step
-                          ) % self.num_layers + self.start_layer
+        next_layer_idx = (layer_idx + self.prefetch_step) % self.num_layers + self.start_layer
         next_layer = self.layers[next_layer_idx - self.start_layer]
         # The index of the window to store the weight for the coming layer.
         next_layer.window_idx = self.window_offset
@@ -123,8 +124,7 @@ class SeriesMetadata:
         if not self.is_source(next_layer_idx):
             next_layer.weight.set_(window.weight)
         # Update `window_offset` by rolling one step.
-        self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
-                                                         1)
+        self.window_offset = (self.window_offset + 1) % (self.prefetch_step + 1)
         assert window.data_layer_idx != next_layer_idx
         window.data_layer_idx = next_layer_idx
         # Start asynchronous broadcast work.
@@ -132,13 +132,13 @@ class SeriesMetadata:
             next_layer.weight,
             src=self.group.ranks[next_layer_idx % self.group.world_size],
             group=self.group.device_group,
-            async_op=True)
+            async_op=True,
+        )
 
     def wait_weight(self, layer_idx: int):
         # Find the asynchronous broadcast work and wait for it.
         assert self.shard_windows
-        window = self.shard_windows[self.layers[layer_idx -
-                                                self.start_layer].window_idx]
+        window = self.shard_windows[self.layers[layer_idx - self.start_layer].window_idx]
         # Make sure the data in the corresponding shard window is for the current layer.
         assert window.data_layer_idx == layer_idx
         if window.work is not None:
@@ -148,8 +148,8 @@ class SeriesMetadata:
 
 @dataclass
 class LayerExternalMetadata:
-    """External metadata for a layer.
-    """
+    """External metadata for a layer."""
+
     series: SeriesMetadata
     layer_idx: int
 
@@ -159,9 +159,7 @@ _series_dict: dict[str, SeriesMetadata] = {}
 _layer_external_dict: dict[int, LayerExternalMetadata] = {}
 
 
-def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
-                            layer_idx: int) -> Callable:
-
+def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, layer_idx: int) -> Callable:
     def wrapped_forward(*args, **kwargs):
         # Wait for the weight.
         series.wait_weight(layer_idx)
@@ -173,23 +171,32 @@ def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
     """
     Register linear layers into a shard storage series.
 
-    In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
+    In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series.
+    All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer
+    is stored on device (i % n), where n is the number of devices.
 
-    After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)` on any layer of this series to complete the initialization.
+    After loading the model, you must call `post_process_after_loading_for_shard_weight_series(layer)`
+    on any layer of this series to complete the initialization.
 
-    During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
+    During execution, each time a new layer is reached, you must call `reach_layer_for_shard_weight_series(layer)`
+    for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages
+    asynchronous weight prefetching. Each call to `reach_layer_for_shard_weight_series(current_layer)` method will
+    trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
 
     Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
     - start_layer is the index of the first layer in the series (inclusive).
-    - end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
+    - end_layer is the index of the last layer in the series (exclusive). Thus, the series includes all layers with
+      indices in the range [start_layer, end_layer).
     - total_layers = end_layer - start_layer
     - prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
 
-    To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers will be created for this series.
+    To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shard tensor buffers
+    will be created for this series.
 
     Arguments:
         series_name: This name identifies which series this layer belongs to.
-        group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
+        group: The group coordinator for handling asynchronous communications. It is recommended to create a new group
+            coordinator for each new series.
         layer: The linear layer object to register.
        prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
     """
@@ -224,7 +231,8 @@ def register_layer_to_shard_weight_series(
             post_method=layer.quant_method.process_weights_after_loading,
             weight=layer.weight,
             window_idx=-1,
-        ))
+        )
+    )
     # Discard the original `process_weights_after_loading` method such that it won't be called by others.
     layer.quant_method.process_weights_after_loading = lambda layer: None
     # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.
@@ -257,6 +265,7 @@ def wait_layer_for_shard_weight_series(layer: LinearBase):
 @lru_cache(maxsize=1)
 def get_current_model_num_hidden_layers() -> int:
     from vllm.config import get_current_vllm_config
+
     vllm_config = get_current_vllm_config()
     return vllm_config.model_config.get_total_num_hidden_layers()
 
@@ -268,10 +277,11 @@ def is_hidden_layer(layer: LinearBase) -> bool:
 
 
 def register_all_layers_to_shard_weight_series(
-        layer_sharding: List[LinearBase], ):
-    for curr_layer in (layer_sharding or []):
+    layer_sharding: list[LinearBase],
+):
+    for curr_layer in layer_sharding or []:
         if is_hidden_layer(curr_layer):
-            layer_name = curr_layer.prefix.split('.')[-1]
+            layer_name = curr_layer.prefix.split(".")[-1]
             register_layer_to_shard_weight_series(
                 series_name=layer_name,
                 group=get_shard_weight_group(),
```