From e56447033889ca95df512208cab22ef832bfdf07 Mon Sep 17 00:00:00 2001
From: cxcxflying <47177129+cxcxflying@users.noreply.github.com>
Date: Tue, 13 May 2025 19:12:40 +0800
Subject: [PATCH] [Attention][Kernel] MoE support for llama4 and mllama4 (#740)

### What this PR does / why we need it?
Adds MoE support for llama4 and mllama4 in vllm-ascend: the Ascend attention
backends accept the `use_irope` argument that llama4 passes, `fused_experts`
gains an `apply_router_weight_on_input` flag (llama4 routes with topk=1 and
applies the router weight on the expert input), and `select_experts` now
honors a model-supplied `custom_routing_function` instead of raising
`NotImplementedError`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Start server:

python -m vllm.entrypoints.openai.api_server --model /data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct \
    --max-num-seqs=256 \
    --max-model-len=8192 \
    --tensor-parallel-size=8 \
    --block-size=128 \
    --dtype bfloat16 \
    --host=0.0.0.0 \
    --port=8000 \
    --gpu-memory-utilization=0.9 \
    --trust-remote-code

Client:

python online_server.py --model-path /data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct --image-path /data/nfs/w60040464/cherry_blossom.jpg --docker-ip 7.242.108.253 --served-port 8000 --text "what is the content of this image?"

Result:

{'id': 'chatcmpl-2b709a5d2e1a4017991ec4ba8248686a', 'object': 'chat.completion', 'created': 1747056823, 'model': '/data/nfs/benchmark/tokenizer/Llama-4-Scout-17B-16E-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'reasoning_content': None, 'content': 'The image depicts a tower, likely Tokyo Skytree, framed by branches of a cherry blossom tree. The tower is white and has a distinctive shape, with a large sphere at the top and a long, thin spire extending from it. The branches of the cherry blossom tree are in the foreground, with pink flowers blooming on them. The background is a clear blue sky.\n\n**Key Features:**\n\n* **Tower:** White, spherical shape at the top, long thin spire\n', 'tool_calls': []}, 'logprobs': None, 'finish_reason': 'length', 'stop_reason': None}], 'usage': {'prompt_tokens': 2340, 'total_tokens': 2440, 'completion_tokens': 100, 'prompt_tokens_details': None}, 'prompt_logprobs': None}
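### How it works
When `apply_router_weight_on_input` is enabled, the router weight is
multiplied into `hidden_states` before expert dispatch, so the finalize step
must combine expert outputs with all-ones scales to avoid applying the weight
twice. A minimal sketch of that semantics in plain PyTorch (illustrative
shapes only, not the NPU kernel path):

```python
import torch

num_tokens, hidden_dim, topk = 4, 8, 1  # llama4 routes with topk=1
hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, topk)

# Router weight is folded into the activations *before* expert dispatch.
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)

# The finalize/combine step must then use all-ones scales; otherwise the
# router weight would be applied twice.
scales = torch.ones_like(topk_weights)
```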
Signed-off-by: chenxu
Co-authored-by: chenxu
Co-authored-by: evian
---
 vllm_ascend/attention/attention.py    |  1 +
 vllm_ascend/attention/attention_v1.py |  1 +
 vllm_ascend/ops/common_fused_moe.py   | 16 ++++++++-------
 vllm_ascend/ops/fused_moe.py          | 28 ++++++++++++++++++++++-----
 4 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py
index cb4f745..d598822 100644
--- a/vllm_ascend/attention/attention.py
+++ b/vllm_ascend/attention/attention.py
@@ -708,6 +708,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d594b8e..b5f6f39 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -174,6 +174,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 1f2bf43..43c9517 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -55,13 +55,15 @@ def forward_oot(
         e_score_correction_bias=e_score_correction_bias,
     )

-    return fused_experts(hidden_states=x,
-                         w1=layer.w13_weight,
-                         w2=layer.w2_weight,
-                         topk_weights=topk_weights,
-                         topk_ids=topk_ids,
-                         top_k=top_k,
-                         expert_map=expert_map)
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        top_k=top_k,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input)


 UnquantizedFusedMoEMethod.forward_oot = forward_oot
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index c912303..906d77c 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -153,6 +153,7 @@ def fused_experts(
     topk_ids: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -191,6 +192,15 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     #                  ], "Only float32, float16, and bfloat16 are supported"

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+
     if expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
@@ -292,6 +302,8 @@ def fused_experts(
                                    torch.zeros_like(weighted_down_out)).to(dtype)
         final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
+        scales = torch.ones_like(
+            topk_weights) if apply_router_weight_on_input else topk_weights
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
@@ -299,7 +311,7 @@ def fused_experts(
             skip1=None,
             skip2=None,
             bias=None,
-            scales=topk_weights,
+            scales=scales,
             expanded_src_to_dst_row=expanded_row_idx,
             export_for_source_row=topk_ids,
         )
@@ -366,9 +378,6 @@ def select_experts(
     Raises:
         ValueError: If an unsupported scoring function is provided.
     """
-    if custom_routing_function is not None:
-        raise NotImplementedError(
-            "Custom routing function is not supported now")

     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
@@ -405,9 +414,18 @@ def select_experts(
                                               k=top_k,
                                               dim=-1,
                                               sorted=False)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
         topk_weights = topk_weights.to(hidden_states.dtype)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize)
+        # Required by npu_moe_init_routing
+        topk_ids = topk_ids.to(torch.int32)
+        return topk_weights, topk_ids

     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
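
Note on the custom-routing path: `select_experts` previously raised
`NotImplementedError` when a `custom_routing_function` was supplied; it now
delegates to it and converts the returned ids to int32 for
`npu_moe_init_routing`. The stand-in below only illustrates the expected call
contract (llama4's real routing function ships with upstream vLLM; the
topk-then-sigmoid body here is an assumption for illustration):

```python
import torch

def toy_routing_function(hidden_states: torch.Tensor,
                         gating_output: torch.Tensor,
                         topk: int,
                         renormalize: bool):
    # Pick the top-k router logits per token, then turn them into weights.
    topk_weights, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
    topk_weights = torch.sigmoid(topk_weights.float())
    return topk_weights, topk_ids

hidden = torch.randn(4, 8)          # 4 tokens, hidden size 8
router_logits = torch.randn(4, 16)  # 16 experts
weights, ids = toy_routing_function(hidden, router_logits, topk=1,
                                    renormalize=False)
ids = ids.to(torch.int32)           # npu_moe_init_routing expects int32 ids
```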