diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py
index 51e0cb9f..39200a86 100644
--- a/vllm_ascend/ops/fused_moe/experts_selector.py
+++ b/vllm_ascend/ops/fused_moe/experts_selector.py
@@ -225,7 +225,7 @@ def _select_experts_with_fusion_ops(
         norm_type=norm_type,  # 0: softmax; 1: sigmoid
         # out_flag=False, # todo new api; should the third output be output
         # y2_flag=False, # old api; should the third output be output
-        routed_scaling_factor=1,
+        routed_scaling_factor=routed_scaling_factor,
         eps=float(1e-20))
     if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
@@ -304,3 +304,28 @@ def _native_select_experts(
     topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids
+
+
+def zero_experts_compute(
+    expert_indices: torch.Tensor,
+    expert_scales: torch.Tensor,
+    num_experts: int,
+    zero_expert_type: str,
+    hidden_states: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales = torch.where(zero_expert_mask, 0.0,
+                                         zero_expert_scales)
+
+        hidden_states = hidden_states.unsqueeze(1)
+        zero_expert_scales = zero_expert_scales.unsqueeze(2)
+        result = hidden_states * zero_expert_scales
+        result = result.sum(dim=1)
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices = torch.where(normal_expert_mask, 0, expert_indices)
+    expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales)
+
+    return expert_indices, expert_scales, result
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 4a6a953e..efc709a3 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -35,7 +35,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
 from vllm_ascend.eplb.utils import moe_load_async_stream
 from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                                set_flash_common3_context)
-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
+                                                        zero_experts_compute)
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        FusedExpertsResult,
                                                        setup_moe_comm_method)
@@ -92,7 +93,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         enable_force_load_balance: bool = False,
         shared_experts: Optional[Any] = None,
         **kwargs) -> torch.Tensor:
-
+        zero_expert_num = getattr(layer, "zero_expert_num", 0)
+        zero_expert_type = getattr(layer, "zero_expert_type", None)
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -107,6 +109,15 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)
 
+        if zero_expert_num > 0 and zero_expert_type is not None:
+            topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
+                expert_indices=topk_ids,
+                expert_scales=topk_weights,
+                num_experts=global_num_experts,
+                zero_expert_type=zero_expert_type,
+                hidden_states=x,
+            )
+
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
@@ -119,7 +130,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
-        return moe_comm_method.fused_experts(
+        final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
@@ -131,6 +142,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             apply_router_weight_on_input=apply_router_weight_on_input,
             dynamic_eplb=self.dynamic_eplb,
             mc2_mask=kwargs.get("mc2_mask", None))
+        if zero_expert_num > 0 and zero_expert_type is not None:
+            final_hidden_states += zero_expert_result
+        return final_hidden_states
 
 
 class AscendFusedMoE(FusedMoE):
@@ -340,6 +354,7 @@ class AscendFusedMoE(FusedMoE):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py
index 1cedda9c..1c952aa6 100644
--- a/vllm_ascend/ops/mla.py
+++ b/vllm_ascend/ops/mla.py
@@ -94,8 +94,6 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         hf_config = get_current_vllm_config().model_config.hf_config
         self.enable_shared_expert_dp = get_ascend_config(
         ).enable_shared_expert_dp
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        self.first_k_dense_replace = hf_config.first_k_dense_replace
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layers = hf_config.num_hidden_layers
         if mla_modules.indexer is not None:
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 8c8b7518..49a1a5ba 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -298,6 +298,12 @@ packed_modules_model_mapping = {
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
     },
+    "longcat_flash": {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
+    },
 }
 
 
@@ -514,6 +520,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = False,
@@ -524,9 +531,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         return self.quant_method.apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
-            custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance, log2phy,
-            global_redundant_expert_num, **kwargs)
+            custom_routing_function, scoring_func, routed_scaling_factor,
+            e_score_correction_bias, is_prefill, enable_force_load_balance,
+            log2phy, global_redundant_expert_num, **kwargs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/w4a16.py
index d15fa25a..4fcc3380 100644
--- a/vllm_ascend/quantization/w4a16.py
+++ b/vllm_ascend/quantization/w4a16.py
@@ -199,6 +199,7 @@ class AscendW4A16FusedMoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = True,
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index 45a7bc18..3222f2ea 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -336,6 +336,7 @@ class AscendW4A8DynamicFusedMoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = False,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 986f6fd2..bebd807b 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -28,7 +28,8 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.flash_common3_context import get_flash_common3_context
-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
+                                                        zero_experts_compute)
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
 
 
@@ -183,6 +184,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = False,
@@ -194,8 +196,11 @@ class AscendW8A8DynamicFusedMoEMethod:
         pertoken_scale: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
-        assert router_logits.shape[
-            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
+        zero_expert_num = getattr(layer, "zero_expert_num", 0)
+        zero_expert_type = getattr(layer, "zero_expert_type", None)
+        if zero_expert_num == 0 or zero_expert_type is None:
+            assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \
+                "Number of global experts mismatch (excluding redundancy)"
 
         if self.multistream_overlap_gate:
             fc3_context = get_flash_common3_context()
@@ -213,10 +218,19 @@ class AscendW8A8DynamicFusedMoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)
         assert topk_ids is not None
         assert topk_weights is not None
+        if zero_expert_num > 0 and zero_expert_type is not None:
+            topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
+                expert_indices=topk_ids,
+                expert_scales=topk_weights,
+                num_experts=global_num_experts,
+                zero_expert_type=zero_expert_type,
+                hidden_states=x,
+            )
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
@@ -253,7 +267,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         fused_scale_flag = (get_forward_context().moe_comm_type
                             == MoECommType.FUSED_MC2
                             and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
-        return moe_comm_method.fused_experts(
+        final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
             w1=w1,
@@ -271,6 +285,9 @@ class AscendW8A8DynamicFusedMoEMethod:
             dynamic_scale_for_share=dynamic_scale_for_share,
             dynamic_eplb=self.dynamic_eplb,
             mc2_mask=kwargs.get("mc2_mask", None))
+        if zero_expert_num > 0 and zero_expert_type is not None:
+            final_hidden_states += zero_expert_result
+        return final_hidden_states
 
     def process_weights_after_loading(self, layer):
         layer.w13_weight.data = layer.w13_weight.data.transpose(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 319c7e41..b4b7e3a0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2240,9 +2240,10 @@ class NPUModelRunner(GPUModelRunner):
                 kv_caches[layer_name] = kv_caches[target_layer_name]
 
         from vllm.v1.worker.utils import bind_kv_cache
+        num_attn_module = 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1
         bind_kv_cache(kv_caches,
                       self.compilation_config.static_forward_context,
-                      self.kv_caches)
+                      self.kv_caches, num_attn_module)
 
         return kv_caches
 
     def _allocate_kv_cache_tensors(
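
Note on the zero-expert path added above: zero_experts_compute treats routed ids greater than or equal to num_experts as "identity" zero experts, folds their routing weights directly into the token's hidden states, and repoints those top-k slots to expert 0 with weight 0 so the fused experts kernel skips them; the identity contribution is then added to final_hidden_states after fused_experts. Below is a minimal sketch of that bookkeeping; the toy shapes and routing values are made up for illustration, and the helper body is a condensed copy of the experts_selector.py hunk so the snippet runs with plain torch (no NPU stack required).

import torch


def zero_experts_compute(expert_indices, expert_scales, num_experts,
                         zero_expert_type, hidden_states):
    # Condensed copy of the helper added in experts_selector.py above: fold
    # the weights of "identity" zero experts into the hidden states, then
    # zero out those top-k slots so the routed-expert kernel ignores them.
    if zero_expert_type == "identity":
        zero_expert_mask = expert_indices < num_experts
        zero_expert_scales = torch.where(zero_expert_mask, 0.0,
                                         expert_scales.clone())
        result = (hidden_states.unsqueeze(1) *
                  zero_expert_scales.unsqueeze(2)).sum(dim=1)

    normal_expert_mask = expert_indices >= num_experts
    expert_indices = torch.where(normal_expert_mask, 0, expert_indices)
    expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales)
    return expert_indices, expert_scales, result


# Toy example: 2 tokens, hidden size 3, top-2 routing over 4 routed experts;
# ids >= 4 denote identity zero experts.
hidden_states = torch.randn(2, 3)
topk_ids = torch.tensor([[1, 5], [4, 2]])
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])

topk_ids, topk_weights, zero_result = zero_experts_compute(
    topk_ids, topk_weights, 4, "identity", hidden_states)

# Zero-expert slots are repointed to expert 0 with weight 0; their identity
# contribution comes back as zero_result and is added to final_hidden_states
# after fused_experts in the hunks above.
assert topk_ids.tolist() == [[1, 0], [0, 2]]
assert torch.allclose(zero_result[0], 0.3 * hidden_states[0])
assert torch.allclose(zero_result[1], 0.6 * hidden_states[1])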