diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 1b68590d..f8b9d1cd 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -253,12 +253,24 @@ def select_moe_comm_method(num_tokens: int, ascend_config = get_ascend_config() dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes - fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group( - ).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model) + # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs + # TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8 + fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and ( + not dynamic_eplb) if num_tokens <= mc2_tokens_capacity: - moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2 + fused_decode_enable = fused_mc2_enable + if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: + fused_decode_enable = fused_mc2_enable and get_ep_group( + ).world_size <= 16 and (not is_mtp_model) + moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2 else: - moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL + fused_prefill_enable = fused_mc2_enable + if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: + fused_prefill_enable = fused_mc2_enable and get_ep_group( + ).world_size <= 16 and (not is_mtp_model) + elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: + fused_prefill_enable = False + moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL else: raise ValueError(f"Unsupported soc_version: {soc_version}") diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index c3c5e967..75024663 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -135,7 +135,13 @@ env_variables: Dict[str, Callable[[], Any]] = { # Whether to anbale dynamic EPLB "DYNAMIC_EPLB": lambda: os.getenv("DYNAMIC_EPLB", "false").lower(), - # Whether to anbale fused mc2(dispatch_gmm_combine_decode/dispatch_ffn_combine operator) + # Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator) + # 0, or not set: default ALLTOALL and MC2 will be used. + # 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator. + # `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb. + # 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator. + # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer + # with W8A8, non-dynamic-eplb. And MTP layer must be W8A8. "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')), } diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 9d913f63..b4cbbb48 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -345,7 +345,7 @@ class AscendFusedMoE(FusedMoE): shared_out = fc3_context.shared_experts(hidden_states) # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \ + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ and not shared_expert_dp_enabled(): shared_out = tensor_model_parallel_all_reduce(shared_out) set_flash_common3_context(shared_out=shared_out) diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index e9dc7b0a..30d1e5c1 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -291,9 +291,9 @@ class FusedMC2CommImpl(MoECommMethod): assert not ( w1_scale is None or w2_scale is None ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." - out = torch.empty_like(hidden_states) if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: + out = torch.empty_like(hidden_states) torch.ops._C_ascend.dispatch_ffn_combine( x=hidden_states, weight1=w1[0], @@ -307,7 +307,21 @@ class FusedMC2CommImpl(MoECommMethod): out=out, ) elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: - raise NotImplementedError() + assert expert_map is not None, "expert_map cannot be None." + out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( + x=hidden_states, + expert_ids=topk_ids, + gmm1_permuted_weight=w1[0], + gmm1_permuted_weight_scale=w1_scale[0], + gmm2_weight=w2[0], + gmm2_weight_scale=w2_scale[0], + expert_smooth_scales=None, + expert_scales=topk_weights.to(torch.float32), + group_ep=self.token_dispatcher.moe_all_to_all_group_name, + ep_rank_size=self.token_dispatcher.ep_world_size, + ep_rank_id=self.token_dispatcher.ep_rank_id, + moe_expert_num=len(expert_map), + global_bs=self.token_dispatcher.fused_global_bs) else: raise ValueError( f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index 1b18a488..aeb751d0 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -125,6 +125,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512) num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size self.global_bs = num_tokens_per_tp_rank * self.ep_world_size + self.fused_global_bs = max_num_tokens * self.ep_world_size def get_dispatch_mc2_kwargs( self, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index e32360ce..da6d3a69 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -231,6 +231,10 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method + # When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale + w2_weight_scale_fp32_flag = ( + get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 + and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2) if self.dynamic_eplb: w1 = layer.w13_weight_list w1_scale = layer.w13_weight_scale_fp32_list @@ -240,7 +244,10 @@ class AscendW8A8DynamicFusedMoEMethod: w1 = [layer.w13_weight] w1_scale = [layer.w13_weight_scale_fp32] w2 = [layer.w2_weight] - w2_scale = [layer.w2_weight_scale] + w2_scale = [ + layer.w2_weight_scale_fp32 + if w2_weight_scale_fp32_flag else layer.w2_weight_scale + ] fused_scale_flag = (get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 @@ -279,6 +286,8 @@ class AscendW8A8DynamicFusedMoEMethod: layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( layer.w2_weight_scale.data.shape[0], -1) + layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to( + torch.float32) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( layer.w2_weight_offset.data.shape[0], -1)