diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index cb4f745..d598822 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -708,6 +708,7 @@ class AscendAttentionBackendImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + use_irope: bool = False, ) -> None: self.num_heads = num_heads self.head_size = head_size diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index d594b8e..b5f6f39 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -174,6 +174,7 @@ class AscendAttentionBackendImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + use_irope: bool = False, ) -> None: self.num_heads = num_heads self.head_size = head_size diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 1f2bf43..43c9517 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -55,13 +55,15 @@ def forward_oot( e_score_correction_bias=e_score_correction_bias, ) - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input) UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c912303..906d77c 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -153,6 +153,7 @@ def fused_experts( topk_ids: torch.Tensor, top_k: int, expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ Fused experts with top-k routing. @@ -191,6 +192,15 @@ def fused_experts( # assert dtype in [torch.float32, torch.float16, torch.bfloat16 # ], "Only float32, float16, and bfloat16 are supported" + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + if expert_map is not None: # Generate token indices and flatten token_indices = (torch.arange(num_tokens, @@ -292,6 +302,8 @@ def fused_experts( torch.zeros_like(weighted_down_out)).to(dtype) final_hidden_states.index_add_(0, sorted_token_indices, valid_output) else: + scales = torch.ones_like( + topk_weights) if apply_router_weight_on_input else topk_weights # TODO: Reorder device memory 2 times here, replace the current # implementation here when suitable operators become available. final_hidden_states = torch_npu.npu_moe_finalize_routing( @@ -299,7 +311,7 @@ def fused_experts( skip1=None, skip2=None, bias=None, - scales=topk_weights, + scales=scales, expanded_src_to_dst_row=expanded_row_idx, export_for_source_row=topk_ids, ) @@ -366,9 +378,6 @@ def select_experts( Raises: ValueError: If an unsupported scoring function is provided. """ - if custom_routing_function is not None: - raise NotImplementedError( - "Custom routing function is not supported now") if scoring_func == "softmax": # NOTE: vLLM use dtype=torch.float here @@ -405,9 +414,18 @@ def select_experts( k=top_k, dim=-1, sorted=False) - else: + elif custom_routing_function is None: topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) topk_weights = topk_weights.to(hidden_states.dtype) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + # Required by npu_moe_init_routing + topk_ids = topk_ids.to(torch.int32) + return topk_weights, topk_ids # Required by npu_moe_init_routing topk_ids = topk_ids.to(torch.int32)