init v0.11.0rc0
This commit is contained in:
120
vllm_ascend/torchair/ops/sequence_parallel.py
Normal file
120
vllm_ascend/torchair/ops/sequence_parallel.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
class MetadataForPadding:
|
||||
|
||||
def __init__(self,
|
||||
padding_flag=False,
|
||||
lengths_sum_padding=0,
|
||||
lengths_sum_unpadding=0,
|
||||
pad_size=0,
|
||||
not_dummy_and_is_prefill=False):
|
||||
self.padding_flag = padding_flag
|
||||
self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
|
||||
|
||||
self.lengths_sum_padding = lengths_sum_padding
|
||||
self.lengths_sum_unpadding = lengths_sum_unpadding
|
||||
self.pad_size = pad_size
|
||||
|
||||
self.tp_size = get_tp_group().world_size
|
||||
self.tp_rank_in_group = get_tp_group().rank_in_group
|
||||
|
||||
assert self.lengths_sum_padding % self.tp_size == 0
|
||||
self.slice_size = self.lengths_sum_padding // self.tp_size
|
||||
|
||||
self.mc2_mask = torch.zeros(
|
||||
self.lengths_sum_padding,
|
||||
dtype=torch.bool,
|
||||
device=NPUPlatform.device_type,
|
||||
)
|
||||
self.mc2_mask[:lengths_sum_unpadding] = True
|
||||
|
||||
def padding_aligned_reduce_scatter(self,
|
||||
data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
|
||||
padded_data, 0)
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
def allgather_unpadding_aligned(self,
|
||||
padded_data: torch.Tensor) -> torch.Tensor:
|
||||
padded_data_allgather = tensor_model_parallel_all_gather(
|
||||
padded_data, 0)
|
||||
if self.padding_flag:
|
||||
lengths_sum_unpadding = self.lengths_sum_unpadding
|
||||
unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
|
||||
else:
|
||||
unpadding_data = padded_data_allgather
|
||||
return unpadding_data
|
||||
|
||||
def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
padded_data = F.pad(data, (0, 0, 0, self.pad_size))
|
||||
start = self.tp_rank_in_group * self.slice_size
|
||||
end = start + self.slice_size
|
||||
slice_data = padded_data[start:end]
|
||||
|
||||
return slice_data
|
||||
|
||||
def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
# padded_data = data
|
||||
padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
|
||||
|
||||
padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
|
||||
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
|
||||
if not enable_sequence_parallelism:
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
|
||||
is_perifll = 0
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if attn_metadata is not None:
|
||||
if hasattr(attn_metadata,
|
||||
'is_only_prefill') and attn_metadata.is_only_prefill:
|
||||
is_perifll = 1
|
||||
if hasattr(attn_metadata,
|
||||
'num_prefills') and attn_metadata.num_prefills > 0:
|
||||
is_perifll = 1
|
||||
|
||||
if is_perifll:
|
||||
lengths_sum_unpadding = input_ids.shape[0]
|
||||
lengths_sum_padding = (
|
||||
(lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
|
||||
if lengths_sum_unpadding == lengths_sum_padding:
|
||||
padding_flag = False
|
||||
else:
|
||||
padding_flag = True
|
||||
pad_size = lengths_sum_padding - lengths_sum_unpadding
|
||||
_metadata_for_padding = MetadataForPadding(
|
||||
lengths_sum_unpadding=lengths_sum_unpadding,
|
||||
lengths_sum_padding=lengths_sum_padding,
|
||||
padding_flag=padding_flag,
|
||||
pad_size=pad_size,
|
||||
not_dummy_and_is_prefill=True)
|
||||
|
||||
return _metadata_for_padding
|
||||
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
245
vllm_ascend/torchair/ops/shared_weight_layer.py
Normal file
245
vllm_ascend/torchair/ops/shared_weight_layer.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerMetadata:
|
||||
"""Metadata for a layer.
|
||||
"""
|
||||
layer: Optional[LinearBase] # The layer object.
|
||||
post_method: Callable[[
|
||||
torch.nn.Module
|
||||
], None] # The `process_weights_after_loading` method from the quant method.
|
||||
weight: torch.Tensor # The weight tensor.
|
||||
window_idx: int # The index of the window.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedWindowMetadata:
|
||||
"""Metadata for a shared window.
|
||||
"""
|
||||
weight: torch.Tensor # The weight tensor to be shared by layers.
|
||||
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
||||
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeriesMetadata:
|
||||
"""Metadata for a weight shared series.
|
||||
"""
|
||||
group: GroupCoordinator
|
||||
start_layer: int
|
||||
end_layer: int
|
||||
num_layers: int
|
||||
prefetch_step: int
|
||||
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
||||
layers: list[LayerMetadata]
|
||||
shared_windows: list[
|
||||
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
||||
window_offset: int # The index of the window for the next coming layer.
|
||||
|
||||
def is_source(self, layer_idx) -> bool:
|
||||
return layer_idx % self.group.world_size == self.group.rank_in_group
|
||||
|
||||
def post_process_after_loading(self):
|
||||
# This method only needs to be called once per series.
|
||||
if self.shared_windows:
|
||||
return
|
||||
for layer_idx in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[layer_idx - self.start_layer]
|
||||
is_source = self.is_source(layer_idx)
|
||||
# If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight.
|
||||
if not is_source:
|
||||
layer.weight.set_(torch.empty_like(self.dummy_weight))
|
||||
# Broadcast to get the true weight.
|
||||
dist.broadcast(layer.weight,
|
||||
src=self.group.ranks[layer_idx %
|
||||
self.group.world_size],
|
||||
group=self.group.device_group)
|
||||
assert layer.layer is not None
|
||||
# Call `process_weights_after_loading` from the quant method.
|
||||
layer.post_method(layer.layer)
|
||||
step = layer_idx - self.start_layer
|
||||
if step < self.prefetch_step:
|
||||
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=layer.weight.clone().detach(),
|
||||
data_layer_idx=layer_idx,
|
||||
work=None,
|
||||
))
|
||||
layer.window_idx = step
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not is_source:
|
||||
layer.weight.set_(self.shared_windows[-1].weight)
|
||||
else:
|
||||
# Build one more window for prefetch. The weight is useless, so just keep the shape.
|
||||
if step == self.prefetch_step:
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=torch.empty_like(layer.weight),
|
||||
data_layer_idx=-1,
|
||||
work=None,
|
||||
))
|
||||
# When the layer not intended to be stored in this device, dispose the tensor.
|
||||
if not is_source:
|
||||
dispose_tensor(layer.weight)
|
||||
|
||||
dispose_tensor(self.dummy_weight)
|
||||
|
||||
def reach_layer(self, layer_idx: int):
|
||||
# The index of the layer to be prefetched.
|
||||
next_layer_idx = (layer_idx + self.prefetch_step
|
||||
) % self.num_layers + self.start_layer
|
||||
next_layer = self.layers[next_layer_idx - self.start_layer]
|
||||
# The index of the window to store the weight for the coming layer.
|
||||
next_layer.window_idx = self.window_offset
|
||||
window = self.shared_windows[next_layer.window_idx]
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not self.is_source(next_layer_idx):
|
||||
next_layer.weight.set_(window.weight)
|
||||
# Update `window_offset` by rolling one step.
|
||||
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
|
||||
1)
|
||||
assert window.data_layer_idx != next_layer_idx
|
||||
window.data_layer_idx = next_layer_idx
|
||||
# Start asynchronous broadcast work.
|
||||
window.work = dist.broadcast(
|
||||
next_layer.weight,
|
||||
src=self.group.ranks[next_layer_idx % self.group.world_size],
|
||||
group=self.group.device_group,
|
||||
async_op=True)
|
||||
|
||||
def wait_weight(self, layer_idx: int):
|
||||
# Find the asynchronous broadcast work and wait for it.
|
||||
assert self.shared_windows
|
||||
window = self.shared_windows[self.layers[layer_idx -
|
||||
self.start_layer].window_idx]
|
||||
# Make sure the data in the corresponding shared window is for the current layer.
|
||||
assert window.data_layer_idx == layer_idx
|
||||
if window.work is not None:
|
||||
window.work.wait()
|
||||
window.work = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerExternalMetadata:
|
||||
"""External metadata for a layer.
|
||||
"""
|
||||
series: SeriesMetadata
|
||||
layer_idx: int
|
||||
|
||||
|
||||
_series_dict: dict[str, SeriesMetadata] = {}
|
||||
|
||||
_layer_external_dict: dict[int, LayerExternalMetadata] = {}
|
||||
|
||||
|
||||
def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
|
||||
layer_idx: int) -> Callable:
|
||||
|
||||
def wrapped_forward(*args, **kwargs):
|
||||
# Wait for the weight.
|
||||
series.wait_weight(layer_idx)
|
||||
return forward(*args, **kwargs)
|
||||
|
||||
return wrapped_forward
|
||||
|
||||
|
||||
"""
|
||||
Register linear layers into a shared storage series.
|
||||
|
||||
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
|
||||
|
||||
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
|
||||
|
||||
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
|
||||
|
||||
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
|
||||
- total_layers = end_layer - start_layer
|
||||
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
|
||||
|
||||
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
|
||||
|
||||
Arguments:
|
||||
series_name: This name identifies which series this layer belongs to.
|
||||
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
|
||||
start_layer: The index of the first layer in the series (inclusive).
|
||||
end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
|
||||
layer_idx: The index of the current layer.
|
||||
layer: The linear layer object to register.
|
||||
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
|
||||
"""
|
||||
|
||||
|
||||
def register_layer_to_shared_weight_series(
|
||||
series_name: str,
|
||||
group: GroupCoordinator,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
layer_idx: int,
|
||||
layer: LinearBase,
|
||||
prefetch_step: int = 1,
|
||||
):
|
||||
global _series_dict
|
||||
if series_name not in _series_dict:
|
||||
num_layers = end_layer - start_layer
|
||||
assert num_layers > 0
|
||||
assert prefetch_step >= 0 and prefetch_step <= num_layers - 2
|
||||
_series_dict[series_name] = SeriesMetadata(
|
||||
group=group,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
num_layers=num_layers,
|
||||
prefetch_step=prefetch_step,
|
||||
dummy_weight=torch.empty_like(layer.weight),
|
||||
layers=[
|
||||
LayerMetadata(
|
||||
layer=None,
|
||||
post_method=lambda layer: None,
|
||||
weight=torch.empty([]),
|
||||
window_idx=-1,
|
||||
) for _ in range(num_layers)
|
||||
],
|
||||
shared_windows=[],
|
||||
window_offset=prefetch_step,
|
||||
)
|
||||
series = _series_dict[series_name]
|
||||
assert layer.quant_method is not None
|
||||
series.layers[layer_idx - start_layer] = LayerMetadata(
|
||||
layer=layer,
|
||||
post_method=layer.quant_method.process_weights_after_loading,
|
||||
weight=layer.weight,
|
||||
window_idx=-1,
|
||||
)
|
||||
# Discard the original `process_weights_after_loading` method such that it won't be called by others.
|
||||
layer.quant_method.process_weights_after_loading = lambda layer: None
|
||||
# When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.
|
||||
if not series.is_source(layer_idx):
|
||||
dispose_tensor(layer.weight)
|
||||
layer.weight.weight_loader = lambda *args, **kwargs: None
|
||||
layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
|
||||
global _layer_external_dict
|
||||
_layer_external_dict[id(layer)] = LayerExternalMetadata(
|
||||
series=series,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.post_process_after_loading()
|
||||
|
||||
|
||||
def reach_layer_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.reach_layer(ext.layer_idx)
|
||||
37
vllm_ascend/torchair/ops/torchair_activation.py
Normal file
37
vllm_ascend/torchair/ops/torchair_activation.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""AscendSiluAndMul forward in torchair mode.
|
||||
|
||||
The key difference from the original implementation is the removal of operators
|
||||
from the torch.ops.vllm class, as these operators only function in non-torchair
|
||||
modes. Adding them back would cause the graph compilation to fail.
|
||||
"""
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
if is_310p():
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
return out
|
||||
@@ -40,17 +40,18 @@ from vllm.model_executor.layers.quantization.base_config import \
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||
determine_default_log2phy_map)
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_ascend_soc_version,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
get_rm_router_logits_state, is_310p,
|
||||
vllm_version_is)
|
||||
|
||||
|
||||
def torchair_fused_experts_with_mc2(
|
||||
@@ -802,6 +803,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
@@ -883,6 +885,8 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
|
||||
fused_moe_state = FusedMoEState.All2All
|
||||
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return torchair_fused_experts_with_mc2(
|
||||
@@ -1013,45 +1017,70 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
expert_map_path = ascend_config.expert_map_path
|
||||
if expert_map_path and os.path.exists(expert_map_path):
|
||||
# moe expert load balance
|
||||
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
|
||||
self.global_num_experts)
|
||||
self.local_num_experts, self.expert_map = \
|
||||
expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id,
|
||||
get_ep_group().rank_in_group)
|
||||
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id,
|
||||
get_ep_group().rank_in_group)
|
||||
self.global_redundant_expert_num = \
|
||||
expert_load_balancer.get_global_redundant_expert_num()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||
self.expert_map_path = ascend_config.expert_map_path
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||
# static eplb initializing with expert_map_path
|
||||
if self.expert_map_path and os.path.exists(
|
||||
self.expert_map_path) and os.access(self.expert_map_path,
|
||||
os.R_OK):
|
||||
self.expert_load_balancer = ExpertLoadBalancer(
|
||||
self.expert_map_path, self.global_num_experts)
|
||||
self.local_num_experts, self.expert_map = (
|
||||
self.expert_load_balancer.get_rank_placement_map(
|
||||
self.moe_instance_id, self.ep_rank))
|
||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||
self.moe_instance_id, self.ep_rank).npu()
|
||||
self.global_redundant_expert_num = (
|
||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||
else:
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
# init moe.
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
self.ep_size, self.ep_rank, self.global_num_experts)
|
||||
# dynamic eplb initializing with not expert_map_path
|
||||
if self.dynamic_eplb:
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||
self.global_redundant_expert_num)
|
||||
self.log2phy = determine_default_log2phy_map(
|
||||
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||
self.global_redundant_expert_num)
|
||||
local_num_experts = (torch.sum(self.expert_map != -1)
|
||||
if self.expert_map is not None else num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_moe = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe and \
|
||||
self.multistream_overlap_shared_expert = \
|
||||
ascend_config.multistream_overlap_shared_expert and \
|
||||
self.torchair_graph_enabled
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
self.moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=params_dtype,
|
||||
)
|
||||
if quant_config is None:
|
||||
self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
|
||||
self.moe)
|
||||
@@ -1066,8 +1095,11 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
||||
if self.expert_map is not None else num_experts
|
||||
self.moe_load = None
|
||||
local_num_experts = (torch.sum(self.expert_map != -1)
|
||||
if self.expert_map is not None else num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
@@ -1126,23 +1158,25 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
forward_context = get_forward_context()
|
||||
fused_moe_state = forward_context.fused_moe_state
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
|
||||
fused_moe_state = FusedMoEState.All2All
|
||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||
from vllm_ascend.quantization.w8a8_dynamic import \
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
if self.enable_multistream_moe:
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
TorchairAscendW8A8DynamicFusedMoEMethod
|
||||
if self.multistream_overlap_shared_expert:
|
||||
if not self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
if hasattr(self.quant_method, "quant_method") and \
|
||||
isinstance(self.quant_method.quant_method,
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
TorchairAscendW8A8DynamicFusedMoEMethod
|
||||
) and fused_moe_state == FusedMoEState.MC2:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
if shared_experts:
|
||||
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
||||
if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
@@ -1160,31 +1194,33 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
if fused_moe_state in {FusedMoEState.MC2}:
|
||||
padding_size = forward_context.padded_num_tokens
|
||||
else:
|
||||
# TODO: Determine if we can remove the padding
|
||||
padding_size = tp_size
|
||||
if num_tokens < padding_size and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, padding_size - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits, (0, 0, 0, padding_size - num_tokens))
|
||||
]):
|
||||
if tp_size > 1:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
if not self.enable_shared_expert_dp:
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
|
||||
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
|
||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||
if not replace_allreduce:
|
||||
if fused_moe_state in {FusedMoEState.MC2}:
|
||||
padding_size = forward_context.padded_num_tokens
|
||||
else:
|
||||
# TODO: Determine if we can remove the padding
|
||||
padding_size = tp_size
|
||||
if num_tokens < padding_size and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, padding_size - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits, (0, 0, 0, padding_size - num_tokens))
|
||||
if tp_size > 1:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
if not self.enable_shared_expert_dp:
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
|
||||
if self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.AllGather:
|
||||
@@ -1206,8 +1242,12 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
if vllm_version_is("0.10.2"):
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
else:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_sp(1)
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
if self.rm_router_logits:
|
||||
@@ -1236,7 +1276,8 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||
and self.enable_multistream_moe and not is_prefill else None,
|
||||
and self.multistream_overlap_shared_expert and not is_prefill else
|
||||
None,
|
||||
mc2_mask=mc2_mask,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
@@ -1246,6 +1287,11 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if self.dynamic_eplb and isinstance(
|
||||
e_hidden_states, tuple) and len(e_hidden_states) == 3:
|
||||
self.moe_load += e_hidden_states[2] if e_hidden_states[1] == 0 else \
|
||||
torch.cat(e_hidden_states[2][:1], e_hidden_states[2][1:] - e_hidden_states[2][:-1])
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
@@ -1269,8 +1315,8 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
final_hidden_states = final_hidden_states[start:end, :]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = data_parallel_reduce_scatter(
|
||||
e_hidden_states, dim=0)
|
||||
final_hidden_states = get_dp_group().reduce_scatter(
|
||||
e_hidden_states, 0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
@@ -1290,6 +1336,19 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
else:
|
||||
return final_hidden_states
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
self.expert_map = new_expert_map
|
||||
|
||||
def get_map(self):
|
||||
return self.expert_map
|
||||
|
||||
def get_log2phy_map(self):
|
||||
return self.logical_to_physical_map
|
||||
|
||||
def clear_moe_load(self):
|
||||
if self.moe_load is not None:
|
||||
self.moe_load.zero_()
|
||||
|
||||
# ----------------------------------------- TBO-related --------------------------------------------
|
||||
|
||||
def _forward_ms_fused_moe_comp(
|
||||
|
||||
51
vllm_ascend/torchair/ops/torchair_layernorm.py
Normal file
51
vllm_ascend/torchair/ops/torchair_layernorm.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torchair_rmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""AscendRMSNorm forward in torchair mode.
|
||||
|
||||
The key difference from the original implementation is the removal of operators
|
||||
from the torch.ops.vllm class, as these operators only function in non-torchair
|
||||
modes. Adding them back would cause the graph compilation to fail.
|
||||
"""
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
return x
|
||||
@@ -62,7 +62,7 @@ def rope_forward_oot(
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
@@ -93,10 +93,7 @@ def native_rope_deepseek_forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
||||
_set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
@@ -211,8 +208,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
@@ -232,9 +228,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len * self.scaling_factor,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
@@ -365,8 +359,7 @@ def deepseek_rope_init_func(
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
self.max_seq_len = max_position_embeddings
|
||||
_set_cos_sin_cache(self,
|
||||
max_position_embeddings,
|
||||
dtype=dtype,
|
||||
device="npu")
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
|
||||
|
||||
Reference in New Issue
Block a user