diff --git a/vllm_ascend/distributed/communicator.py b/vllm_ascend/distributed/communicator.py index f8291f8..7c14bef 100644 --- a/vllm_ascend/distributed/communicator.py +++ b/vllm_ascend/distributed/communicator.py @@ -14,10 +14,10 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from typing import Optional +from typing import List, Optional import torch -from torch.distributed import ProcessGroup +import torch.distributed as dist from vllm.distributed.device_communicators.base_device_communicator import \ DeviceCommunicatorBase @@ -25,11 +25,51 @@ from vllm.distributed.device_communicators.base_device_communicator import \ class NPUCommunicator(DeviceCommunicatorBase): def __init__(self, - cpu_group: ProcessGroup, + cpu_group: dist.ProcessGroup, device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, + device_group: Optional[dist.ProcessGroup] = None, unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator # init device according to rank self.device = torch.npu.current_device() + + def all_to_all(self, + input_: torch.Tensor, + scatter_dim: int = 0, + gather_dim: int = -1, + scatter_sizes: Optional[List[int]] = None, + gather_sizes: Optional[List[int]] = None) -> torch.Tensor: + + if scatter_dim < 0: + scatter_dim += input_.dim() + if gather_dim < 0: + gather_dim += input_.dim() + + if scatter_sizes is not None and gather_sizes is not None: + input_list = [ + t.contiguous() + for t in torch.split(input_, scatter_sizes, scatter_dim) + ] + output_list = [] + tensor_shape_base = input_list[self.rank].size() + for i in range(self.world_size): + tensor_shape = list(tensor_shape_base) + tensor_shape[gather_dim] = gather_sizes[i] + output_list.append( + torch.empty(tensor_shape, + dtype=input_.dtype, + device=input_.device)) + + else: + input_list = [ + t.contiguous() for t in torch.tensor_split( + input_, self.world_size, scatter_dim) + ] + output_list = [ + torch.empty_like(input_list[i]) for i in range(self.world_size) + ] + + dist.all_to_all(output_list, input_list, group=self.device_group) + output_tensor = torch.cat(output_list, dim=gather_dim).contiguous() + return output_tensor diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 5bf1126..5e97444 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -205,50 +205,66 @@ class CustomDeepseekV2MoE(nn.Module): ) CustomDeepseekV2MoE.top_k = config.num_experts_per_tok - vllm_config = get_current_vllm_config() self.dp_size = get_dp_group().world_size - batch_size = vllm_config.scheduler_config.max_num_seqs - params_dtype = torch.get_default_dtype() - self.final_hidden_states = torch.zeros( - [batch_size, config.hidden_size], dtype=params_dtype, device="npu") self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. if attn_metadata is None: # for profile run is_prefill = True + enable_force_load_balance = True else: is_prefill = attn_metadata.num_prefills > 0 + enable_force_load_balance = False num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill): - chunks = torch.chunk(hidden_states, - get_tp_group().world_size, - dim=0) - hidden_states = chunks[get_tp_group().rank_in_group] - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor - - if self.tp_size > 1: - if VLLM_ENABLE_MC2 and not is_prefill: - dist.all_gather_into_tensor(self.final_hidden_states, - final_hidden_states, self.tp_group) - final_hidden_states = self.final_hidden_states - else: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) + + if self.tp_size > 1: + # pass + num_tokens, hidden_size = hidden_states.shape + if num_tokens < self.tp_size: + target_size = self.tp_size + new_hidden_states = torch.empty([target_size, hidden_size], + dtype=hidden_states.dtype, + device=hidden_states.device) + new_hidden_states[:num_tokens] = hidden_states + hidden_states = new_hidden_states + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + local_hidden_states = chunk_hidden_states[self.tp_rank] + else: + local_hidden_states = hidden_states + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(local_hidden_states) + + router_hidden_states = self.experts( + hidden_states=local_hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekV2MoE.top_k, + enable_force_load_balance=enable_force_load_balance, + ) * self.routed_scaling_factor + + if self.tp_size > 1: + dist.all_gather(list(chunk_hidden_states), router_hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens < self.tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + else: + final_hidden_states = router_hidden_states + + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 906d77c..fd882dd 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -18,7 +18,6 @@ from typing import Callable, Optional import torch -import torch.distributed as dist import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed import tensor_model_parallel_all_reduce @@ -636,6 +635,7 @@ class AscendFusedMoE(FusedMoE): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_prefill: bool, + enable_force_load_balance: bool = False, top_k=None): assert self.quant_method is not None @@ -644,17 +644,8 @@ class AscendFusedMoE(FusedMoE): else: real_top_k = self.top_k - if self.dp_size > 1: - if VLLM_ENABLE_MC2 and not is_prefill: - ... - elif USING_LCCL_COM: # type: ignore - hidden_states = get_dp_group().all_gather( - hidden_states, 0, False) - router_logits = get_dp_group().all_gather( - router_logits, 0, False) - else: - hidden_states = get_dp_group().all_gather(hidden_states, 0) - router_logits = get_dp_group().all_gather(router_logits, 0) + if VLLM_ENABLE_MC2 and not is_prefill: + ... # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -671,17 +662,12 @@ class AscendFusedMoE(FusedMoE): custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, - is_prefill=is_prefill) + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance, + dp_size=self.dp_size) - if self.dp_size > 1: - if VLLM_ENABLE_MC2 and not is_prefill: - ... - else: - final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - final_hidden_states, - "sum", - scatter_dim=0, - group=get_dp_group().device_group) + if VLLM_ENABLE_MC2 and not is_prefill: + ... if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index af19d4a..9369596 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # + # patch_utils should be the first import, because it will be used by other # patch files. import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip +import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_distributed.py b/vllm_ascend/patch/worker/patch_common/patch_distributed.py new file mode 100644 index 0000000..846d82c --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_distributed.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Optional + +import torch +import vllm +from vllm.distributed.parallel_state import GroupCoordinator + + +class GroupCoordinatorPatch(GroupCoordinator): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def all_to_all(self, + input_: torch.Tensor, + scatter_dim: int = 0, + gather_dim: int = -1, + scatter_sizes: Optional[List[int]] = None, + gather_sizes: Optional[List[int]] = None) -> torch.Tensor: + if self.world_size == 1: + return input_ + assert -input_.dim() <= scatter_dim < input_.dim(), ( + f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}" + ) + assert -input_.dim() <= gather_dim < input_.dim(), ( + f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}" + ) + return self.device_communicator.all_to_all(input_, scatter_dim, + gather_dim, scatter_sizes, + gather_sizes) + + +vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving \ No newline at end of file diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 499e236..1aededd 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -321,14 +321,15 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, + enable_force_load_balance: bool = False, + dp_size: int = 1, **kwargs, ) -> torch.Tensor: - return self.quant_method.apply(layer, x, router_logits, top_k, - renormalize, use_grouped_topk, - global_num_experts, expert_map, - topk_group, num_expert_group, - custom_routing_function, scoring_func, - e_score_correction_bias, is_prefill) + return self.quant_method.apply( + layer, x, router_logits, top_k, renormalize, use_grouped_topk, + global_num_experts, expert_map, topk_group, num_expert_group, + custom_routing_function, scoring_func, e_score_correction_bias, + is_prefill, enable_force_load_balance, dp_size) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 97bddba..5d2b442 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,15 +15,19 @@ # limitations under the License. # -import os from typing import Any, Callable, Dict, List, Optional import torch +import torch.distributed as dist import torch_npu +from vllm.distributed import GroupCoordinator +import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts +VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 + def apply_mlp(hidden_states_wrapper: List[torch.Tensor], w1: torch.Tensor, @@ -68,24 +72,18 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], - split_item=3, + scale=[w1_scale], + per_token_scale=[pertoken_scale], + split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=torch.int32)[0] + output_dtype=w2_scale.dtype)[0] # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( @@ -201,6 +199,132 @@ def fused_experts_with_mc2( return hidden_states +# currently expert parallelism implemented with all2all +# is under-optimized. +def fused_experts_with_all2all( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + num_experts = w1.shape[0] + device = hidden_states.device + + if expert_map is not None: + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(top_k, -1).permute( + 1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, + -1).sum(-1) + + gather_sizes = torch.empty_like(scatter_sizes) + dist.all_to_all_single(gather_sizes, + scatter_sizes, + group=ep_group.device_group) + scatter_size_list = scatter_sizes.cpu().tolist() + gather_size_list = gather_sizes.cpu().tolist() + + expanded_expert_idx = expanded_expert_idx % local_num_experts + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + scatter_size_list, + gather_size_list) + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, + scatter_size_list, + gather_size_list) + + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + + hidden_states = hidden_states[sorted_idx] + group_list_type = 0 + else: + row_idx_len = num_tokens * top_k + row_idx = torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous() + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 0 + + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, + w1, + w1_scale, + w2, + w2_scale, + expert_tokens, + group_list_type=group_list_type) + + if expert_map is not None: + resorted_idx = torch.argsort(sorted_idx) + hidden_states = hidden_states[resorted_idx] + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + gather_size_list, + scatter_size_list) + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + else: + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, @@ -387,10 +511,10 @@ class AscendW8A8DynamicFusedMoEMethod: def __init__(self): self.transpose_weight = True - ep_group = get_ep_group() + self.ep_group = get_ep_group() try: - device_group = ep_group.device_group + device_group = self.ep_group.device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) @@ -457,6 +581,8 @@ class AscendW8A8DynamicFusedMoEMethod: scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, + enable_force_load_balance: bool = True, + dp_size: int = 1, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -491,7 +617,13 @@ class AscendW8A8DynamicFusedMoEMethod: e_score_correction_bias=e_score_correction_bias, ) - if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill: + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + if VLLM_ENABLE_MC2 and not is_prefill: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -503,7 +635,7 @@ class AscendW8A8DynamicFusedMoEMethod: top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - else: + elif dp_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, @@ -513,6 +645,17 @@ class AscendW8A8DynamicFusedMoEMethod: topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) + else: + return fused_experts_with_all2all(hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group) def process_weights_after_loading(self, layer): if self.transpose_weight: @@ -521,7 +664,7 @@ class AscendW8A8DynamicFusedMoEMethod: layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( - layer.w13_weight_scale.data.shape[0], -1).to(torch.float32) + layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(