diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 129e5eb..8886972 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -393,7 +393,7 @@ class CustomDeepseekV2MoE(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits = None
-        if not self.rm_router_logits:
+        if not self.rm_router_logits and not self.enable_multistream_moe:
            router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 61205ff..6b3338a 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1334,6 +1334,21 @@ class AscendFusedMoE(FusedMoE):
         forward_context = get_forward_context()
         fused_moe_state = forward_context.fused_moe_state
         mc2_mask = forward_context.mc2_mask
+        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
+        quantized_x_for_share, dynamic_scale_for_share = None, None
+        from vllm_ascend.quantization.w8a8_dynamic import \
+            AscendW8A8DynamicFusedMoEMethod
+        if self.enable_multistream_moe:
+            if not self.rm_router_logits:
+                router_logits, _ = gate(hidden_states)
+            if hasattr(self.quant_method, "quant_method") and \
+                isinstance(self.quant_method.quant_method,
+                           AscendW8A8DynamicFusedMoEMethod
+                           ) and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                        hidden_states)
+
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
@@ -1419,6 +1434,8 @@ class AscendFusedMoE(FusedMoE):
             shared_experts=shared_experts if self.torchair_graph_enabled
             and self.enable_multistream_moe and not is_prefill else None,
             mc2_mask=mc2_mask,
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
         )
 
         if shared_experts:
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index f1667d0..36549e7 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -33,6 +33,82 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
 
 
+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, follow cumsum mode, and
+            with shape (num_experts).
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
@@ -124,6 +200,8 @@ def fused_experts_with_mc2(
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     assert mc2_mask is not None
     if log2phy is not None:
@@ -186,18 +264,19 @@ def fused_experts_with_mc2(
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(quantized_x_for_share, expand_x)
+            npu_wait_tensor(shared_gate_up, expand_x)
             shared_act_out = shared_experts.act_fn(
-                (quantized_x_for_share, dynamic_scale_for_share))
+                (shared_gate_up, shared_dequant_scale))
             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    # `expand_x` will be disposed in the `apply_mlp_decode` function
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs_mc2 = {
@@ -745,6 +824,8 @@ class AscendW8A8DynamicFusedMoEMethod:
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -781,6 +862,16 @@ class AscendW8A8DynamicFusedMoEMethod:
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        fused_moe_state = get_forward_context().fused_moe_state
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
@@ -788,8 +879,6 @@ class AscendW8A8DynamicFusedMoEMethod:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         topk_weights = topk_weights.to(x.dtype)
-
-        fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -806,7 +895,7 @@ class AscendW8A8DynamicFusedMoEMethod:
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -817,7 +906,9 @@ class AscendW8A8DynamicFusedMoEMethod:
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None))
+                mc2_mask=kwargs.get("mc2_mask", None),
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -860,6 +951,8 @@ class AscendW8A8DynamicFusedMoEMethod:
             torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
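
Reviewer note (not part of the patch): the sketch below condenses the stream-ordering that this change sets up for the W8A8-dynamic MC2 decode path, purely as an illustration. The function name multistream_shared_expert_prologue is invented for this sketch; it assumes npu_stream_switch and npu_wait_tensor come from vllm_ascend.utils as in the files above, and that shared_experts.gate_up_proj accepts pre-quantized (tensor, scale) input, as the patched code relies on.

    import torch
    import torch_npu

    from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


    def multistream_shared_expert_prologue(hidden_states: torch.Tensor,
                                           gate: torch.nn.Module,
                                           shared_experts: torch.nn.Module):
        # Main stream: router logits for expert selection
        # (the gate call moved out of CustomDeepseekV2MoE when multistream is on).
        router_logits, _ = gate(hidden_states)

        # Secondary stream: per-token dynamic quantization of the same
        # activations, overlapped with the gate matmul above.
        with npu_stream_switch("moe_secondary", 0):
            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(hidden_states)

        # Secondary stream: the shared experts' gate_up_proj consumes the
        # quantized activations; npu_wait_tensor orders it after router_logits.
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(quantized_x, router_logits)
            share_up_out, _ = shared_experts.gate_up_proj(
                (quantized_x, dynamic_scale))
            shared_gate_up, shared_dequant_scale = share_up_out[0], share_up_out[1]

        # shared_gate_up / shared_dequant_scale are then threaded into
        # fused_experts_with_mc2, where the shared experts' act_fn waits on the
        # dispatched routed tokens (expand_x) before running on the same
        # secondary stream.
        return router_logits, shared_gate_up, shared_dequant_scale

On the routed-expert side, apply_mlp_decode keeps the first grouped matmul in int32 and fuses dequant + SwiGLU + requant in a single npu_dequant_swiglu_quant call, which is why the MC2 decode path calls it instead of apply_mlp.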