From 9d1452c74d174197ddda6c8ae0d58b7fc02dd7a7 Mon Sep 17 00:00:00 2001
From: ichaoren <36871991+ichaoren@users.noreply.github.com>
Date: Thu, 19 Mar 2026 17:19:18 +0800
Subject: [PATCH] [OPS] add split_qkv_tp_rmsnorm_rope ops (#7376)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
This PR introduces a new fused Triton op, `split_qkv_tp_rmsnorm_rope`, for MiniMax-M2.5.
The implementation consists of two Triton kernels:
1. `_split_qkv_and_compute_local_qk_var_kernel`: Splits the QKV input and computes the local (per-TP-rank) variance for RMSNorm.
2. `_apply_global_rmsnorm_kernel`: Applies global RMSNorm (combining the per-rank variances via TP all-reduce) and Neox-style RoPE.
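
Together, the two kernels are roughly equivalent to the unfused PyTorch sketch below. This is only an illustration of the intended semantics; `unfused_reference` and the `all_reduce` hook are illustrative names, not part of this PR.

```python
import torch


def unfused_reference(qkv, q_w, k_w, q_size, kv_size, head_dim,
                      rotary_dim, eps, tp_world, cos, sin, all_reduce):
    # Split this rank's QKV projection output.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    # Kernel 1: local mean of squares over this rank's slice of the heads.
    q_var = q.float().pow(2).mean(-1, keepdim=True)
    k_var = k.float().pow(2).mean(-1, keepdim=True)
    # Host code all-reduces the partial sums across TP ranks; kernel 2 then
    # averages by tp_world (each rank holds an equal share of the normalized
    # dimension) and applies the RMSNorm weight.
    q_var = all_reduce(q_var) / tp_world
    k_var = all_reduce(k_var) / tp_world
    q = (q.float() * torch.rsqrt(q_var + eps) * q_w.float()).to(qkv.dtype)
    k = (k.float() * torch.rsqrt(k_var + eps) * k_w.float()).to(qkv.dtype)

    def neox_rope(x):
        # Rotate the first rotary_dim dims of every head, pass the rest through.
        x = x.view(x.shape[0], -1, head_dim)
        half = rotary_dim // 2
        c, s = cos.float().unsqueeze(1), sin.float().unsqueeze(1)
        x1 = x[..., :half].float()
        x2 = x[..., half:rotary_dim].float()
        rot = torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)
        out = torch.cat([rot, x[..., rotary_dim:].float()], dim=-1)
        return out.to(x.dtype).reshape(x.shape[0], -1)

    return neox_rope(q), neox_rope(k), v
```

With `tp_world == 1` the all-reduce is a no-op and this collapses to a plain per-rank RMSNorm followed by RoPE, which is what the new unit test compares against.
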
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
```bash
pytest tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_tp_rmsnorm_rope.py
```

### Test Data
A3, TP16.

Baseline:

| data | TTFT(ms) | TPOT(ms) | TPS |
|------------|---------:|---------:|-------:|
| 4k/1k@bs1 | 267.55 | 25.5 | 38.85 |
| 4k/1k@bs4 | 542.4 | 26.51 | 148.06 |

With this patch:

| data | TTFT(ms) | TPOT(ms) | TPS |
|------------|---------:|---------:|-------:|
| 4k/1k@bs1 | 234.64 | 20.96 | 47.24 |
| 4k/1k@bs4 | 508.36 | 22.16 | 176.69 |

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: xutianyi
Co-authored-by: xutianyi
---
 .../triton/test_split_qkv_tp_rmsnorm_rope.py  | 189 +++++++++++
 vllm_ascend/ops/__init__.py                   |   1 +
 .../linearnorm/split_qkv_tp_rmsnorm_rope.py   | 310 ++++++++++++++++++
 vllm_ascend/patch/worker/patch_minimax_m2.py  |  30 ++
 4 files changed, 530 insertions(+)
 create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_tp_rmsnorm_rope.py
 create mode 100644 vllm_ascend/ops/triton/linearnorm/split_qkv_tp_rmsnorm_rope.py

diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_tp_rmsnorm_rope.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_tp_rmsnorm_rope.py
new file mode 100644
index 00000000..f7299787
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_split_qkv_tp_rmsnorm_rope.py
@@ -0,0 +1,189 @@
+import gc
+
+import numpy as np
+import pytest
+import torch
+
+import vllm_ascend.ops.register_custom_ops  # noqa
+from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+
+NUM_TOKENS = [1, 8, 32]
+NUM_QKV_HEADS = [(6, 1), (8, 2)]
+HEAD_DIMS = [128]
+ROTARY_DIMS = [64, 128]
+TP_WORLDS = [1]
+EPS = [1e-6]
+DTYPES = [torch.bfloat16]
+SEEDS = [0]
+DEVICES = [f"npu:{0}"]
+DEFAULT_ATOL = 5e-2
+DEFAULT_RTOL = 5e-3
+
+
+def _build_rope(num_tokens, rotary_dim, dtype, device):
+    cos = torch.from_numpy(
+        np.random.uniform(0, 1, [num_tokens, rotary_dim // 2])).to(dtype).to(device)
+    sin = torch.from_numpy(
+        np.random.uniform(0, 1, [num_tokens, rotary_dim // 2])).to(dtype).to(device)
+    return cos.contiguous(), sin.contiguous()
+
+
+def _apply_rope_neox(q, k, cos, sin, rotary_dim):
+    half = rotary_dim // 2
+    cos = cos.to(torch.float32).unsqueeze(1)
+    sin = sin.to(torch.float32).unsqueeze(1)
+
+    q_f32 = q.to(torch.float32)
+    k_f32 = k.to(torch.float32)
+
+    q1 = q_f32[..., :half]
+    q2 = q_f32[..., half:rotary_dim]
+    q_rot = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
+    q_out = torch.cat([q_rot, q_f32[..., rotary_dim:]], dim=-1).to(q.dtype)
+
+    k1 = k_f32[..., :half]
+    k2 = k_f32[..., half:rotary_dim]
+    k_rot = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)
+    k_out = torch.cat([k_rot, k_f32[..., rotary_dim:]], dim=-1).to(k.dtype)
+    return q_out.contiguous(), k_out.contiguous()
+
+
+def _fused_impl(
+    qkv,
+    q_weight,
+    k_weight,
+    q_hidden_size,
+    kv_hidden_size,
+    head_dim,
+    rotary_dim,
+    eps,
+    tp_world,
+    cos,
+    sin,
+):
+    return torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
+        input=qkv,
+        q_weight=q_weight,
+        k_weight=k_weight,
+        q_hidden_size=q_hidden_size,
+        kv_hidden_size=kv_hidden_size,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        eps=eps,
+        tp_world=tp_world,
+        cos=cos,
+        sin=sin,
+    )
+
+
+def _reference_impl(
+    qkv,
+    q_weight,
+    k_weight,
+    q_hidden_size,
+    kv_hidden_size,
+    head_dim,
+    rotary_dim,
+    eps,
+    tp_world,
+    cos,
+    sin,
+):
+    q, k, v = qkv.split([q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
+    orig_dtype = q.dtype
+
+    q_f32 = q.to(torch.float32)
+    k_f32 = k.to(torch.float32)
+    q_var = q_f32.pow(2).mean(dim=-1, keepdim=True)
+    k_var = k_f32.pow(2).mean(dim=-1, keepdim=True)
+
+    q_out = (q_f32 * torch.rsqrt(q_var + eps) * q_weight.to(torch.float32)).to(
+        orig_dtype)
+    k_out = (k_f32 * torch.rsqrt(k_var + eps) * k_weight.to(torch.float32)).to(
+        orig_dtype)
+
+    q_3d = q_out.view(q.shape[0], -1, head_dim).contiguous()
+    k_3d = k_out.view(k.shape[0], -1, head_dim).contiguous()
+    q_3d, k_3d = _apply_rope_neox(q_3d, k_3d, cos.contiguous(), sin.contiguous(),
+                                  rotary_dim)
+
+    return (
+        q_3d.view(q.shape[0], q_hidden_size).contiguous(),
+        k_3d.view(k.shape[0], kv_hidden_size).contiguous(),
+        v.contiguous(),
+    )
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
+@pytest.mark.parametrize("head_dim", HEAD_DIMS)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("tp_world", TP_WORLDS)
+@pytest.mark.parametrize("eps", EPS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_split_qkv_tp_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, head_dim,
+                                   rotary_dim, tp_world, eps, dtype, seed, device):
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.set_default_device(device)
+    init_device_properties_triton()
+
+    q_hidden_size = num_q_heads * head_dim
+    kv_hidden_size = num_kv_heads * head_dim
+
+    qkv = torch.randn(num_tokens,
+                      q_hidden_size + kv_hidden_size * 2,
+                      dtype=dtype,
+                      device=device)
+    q_weight = torch.randn(q_hidden_size, dtype=torch.float32,
+                           device=device) * 0.1 + 1.0
+    k_weight = torch.randn(kv_hidden_size, dtype=torch.float32,
+                           device=device) * 0.1 + 1.0
+    cos, sin = _build_rope(num_tokens, rotary_dim, dtype, device)
+
+    q_fused, k_fused, v_fused = _fused_impl(
+        qkv=qkv.clone(),
+        q_weight=q_weight.clone(),
+        k_weight=k_weight.clone(),
+        q_hidden_size=q_hidden_size,
+        kv_hidden_size=kv_hidden_size,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        eps=eps,
+        tp_world=tp_world,
+        cos=cos,
+        sin=sin,
+    )
+    q_ref, k_ref, v_ref = _reference_impl(
+        qkv=qkv.clone(),
+        q_weight=q_weight.clone(),
+        k_weight=k_weight.clone(),
+        q_hidden_size=q_hidden_size,
+        kv_hidden_size=kv_hidden_size,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        eps=eps,
+        tp_world=tp_world,
+        cos=cos,
+        sin=sin,
+    )
+
+    torch.testing.assert_close(q_fused.to(torch.float32),
+                               q_ref.to(torch.float32),
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(k_fused.to(torch.float32),
+                               k_ref.to(torch.float32),
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(v_fused.to(torch.float32),
+                               v_ref.to(torch.float32),
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
\ No newline at end of file
diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py
index 95b498a5..4ef4f202 100644
--- a/vllm_ascend/ops/__init__.py
+++ b/vllm_ascend/ops/__init__.py
@@ -25,6 +25,7 @@ import vllm_ascend.ops.register_custom_ops  # noqa
 if HAS_TRITON:
     import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope  # noqa
     import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_mrope
+    import vllm_ascend.ops.triton.linearnorm.split_qkv_tp_rmsnorm_rope
 
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
 from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_tp_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_tp_rmsnorm_rope.py
new file mode 100644
index 00000000..af6dd93e
--- /dev/null
+++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_tp_rmsnorm_rope.py
@@ -0,0 +1,310 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+import torch
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
+from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import direct_register_custom_op
+
+from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK": 64}),
+        triton.Config({"BLOCK": 128}),
+        triton.Config({"BLOCK": 256}),
+    ],
+    key=["q_cols", "k_cols"],
+)
+@triton.jit
+def _split_qkv_and_compute_local_qk_var_kernel(
+    input_ptr,
+    q_out_ptr,
+    k_out_ptr,
+    v_out_ptr,
+    qk_var_ptr,
+    num_tokens,
+    q_cols: tl.constexpr,
+    k_cols: tl.constexpr,
+    PAD_Q: tl.constexpr,
+    PAD_K: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    num_programs = tl.num_programs(0)
+    input_row_stride = q_cols + 2 * k_cols
+
+    for idx in tl.range(pid, num_tokens, num_programs):
+        input_base = input_ptr + idx * input_row_stride
+
+        q_in_base = input_base
+        q_out_base = q_out_ptr + idx * q_cols
+        q_sum = tl.zeros((), dtype=tl.float32)
+        q_comp = tl.zeros((), dtype=tl.float32)
+        for q_off in tl.static_range(0, PAD_Q, BLOCK):
+            q_offsets = q_off + tl.arange(0, BLOCK)
+            q_mask = q_offsets < q_cols
+            q_vals = tl.load(q_in_base + q_offsets, mask=q_mask, other=0.0)
+            q_vals_f32 = q_vals.to(tl.float32)
+            q_chunk = tl.sum(q_vals_f32 * q_vals_f32, axis=0)
+            y = q_chunk - q_comp
+            t = q_sum + y
+            q_comp = (t - q_sum) - y
+            q_sum = t
+            tl.store(q_out_base + q_offsets, q_vals, mask=q_mask)
+        q_var = q_sum / q_cols
+
+        k_in_base = input_base + q_cols
+        k_out_base = k_out_ptr + idx * k_cols
+        k_sum = tl.zeros((), dtype=tl.float32)
+        k_comp = tl.zeros((), dtype=tl.float32)
+        for k_off in tl.static_range(0, PAD_K, BLOCK):
+            k_offsets = k_off + tl.arange(0, BLOCK)
+            k_mask = k_offsets < k_cols
+            k_vals = tl.load(k_in_base + k_offsets, mask=k_mask, other=0.0)
+            k_vals_f32 = k_vals.to(tl.float32)
+            k_chunk = tl.sum(k_vals_f32 * k_vals_f32, axis=0)
+            y = k_chunk - k_comp
+            t = k_sum + y
+            k_comp = (t - k_sum) - y
+            k_sum = t
+            tl.store(k_out_base + k_offsets, k_vals, mask=k_mask)
+        k_var = k_sum / k_cols
+
+        v_in_base = input_base + q_cols + k_cols
+        v_out_base = v_out_ptr + idx * k_cols
+        for v_off in tl.static_range(0, PAD_K, BLOCK):
+            v_offsets = v_off + tl.arange(0, BLOCK)
+            v_mask = v_offsets < k_cols
+            v_vals = tl.load(v_in_base + v_offsets, mask=v_mask, other=0.0)
+            tl.store(v_out_base + v_offsets, v_vals, mask=v_mask)
+
+        tl.store(qk_var_ptr + idx * 2, q_var)
+        tl.store(qk_var_ptr + idx * 2 + 1, k_var)
+
+
+@triton.jit
+def _apply_global_rmsnorm_kernel(
+    q_ptr,
+    k_ptr,
+    cos_ptr,
+    sin_ptr,
+    cs_row_stride,
+    q_weight_ptr,
+    k_weight_ptr,
+    qk_global_var_ptr,
+    eps: tl.constexpr,
+    inv_tp_world: tl.constexpr,
+    num_tokens,
+    q_cols: tl.constexpr,
+    k_cols: tl.constexpr,
+    q_num_heads,
+    k_num_heads,
+    head_dim: tl.constexpr,
+    rotary_dim: tl.constexpr,
+    PAD_Q: tl.constexpr,
+    PAD_K: tl.constexpr,
+    PAD_QH: tl.constexpr,
+    PAD_KH: tl.constexpr,
+    PAD_HALF: tl.constexpr,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    num_programs = tl.num_programs(0)
+
+    for idx in tl.range(pid, num_tokens, num_programs):
+        q_gv = tl.load(qk_global_var_ptr + idx * 2).to(tl.float32) * inv_tp_world
+        k_gv = tl.load(qk_global_var_ptr + idx * 2 + 1).to(tl.float32) * inv_tp_world
+        q_scale = 1.0 / tl.sqrt(q_gv + eps)
+        k_scale = 1.0 / tl.sqrt(k_gv + eps)
+
+        q_base = q_ptr + idx * q_cols
+        q_offsets = tl.arange(0, PAD_Q)
+        q_mask = q_offsets < q_cols
+        q_vals = tl.load(q_base + q_offsets, mask=q_mask, other=0.0)
+        q_weight = tl.load(q_weight_ptr + q_offsets, mask=q_mask, other=1.0).to(tl.float32)
+        q_vals = (q_vals.to(tl.float32) * q_scale * q_weight).to(q_vals.dtype)
+        tl.store(q_base + q_offsets, q_vals, mask=q_mask)
+
+        k_base = k_ptr + idx * k_cols
+        k_offsets = tl.arange(0, PAD_K)
+        k_mask = k_offsets < k_cols
+        k_vals = tl.load(k_base + k_offsets, mask=k_mask, other=0.0)
+        k_weight = tl.load(k_weight_ptr + k_offsets, mask=k_mask, other=1.0).to(tl.float32)
+        k_vals = (k_vals.to(tl.float32) * k_scale * k_weight).to(k_vals.dtype)
+        tl.store(k_base + k_offsets, k_vals, mask=k_mask)
+
+        # Neox-style RoPE on the first rotary_dim dimensions of each head
+        half = rotary_dim // 2
+        half_offsets = tl.arange(0, PAD_HALF)
+        half_mask = half_offsets < half
+        cos_row = tl.load(
+            cos_ptr + idx * cs_row_stride + half_offsets,
+            mask=half_mask,
+            other=0.0,
+        ).to(tl.float32)
+        sin_row = tl.load(
+            sin_ptr + idx * cs_row_stride + half_offsets,
+            mask=half_mask,
+            other=0.0,
+        ).to(tl.float32)
+
+        qh_offsets = tl.arange(0, PAD_QH)[:, None] * head_dim + half_offsets[None, :]
+        qh_mask = (tl.arange(0, PAD_QH)[:, None] < q_num_heads) & half_mask[None, :]
+        qh_offsets_2 = qh_offsets + half
+        q1_raw = tl.load(q_base + qh_offsets, mask=qh_mask, other=0.0)
+        q2_raw = tl.load(q_base + qh_offsets_2, mask=qh_mask, other=0.0)
+        q1 = q1_raw.to(tl.float32)
+        q2 = q2_raw.to(tl.float32)
+        qn1 = q1 * cos_row[None, :] - q2 * sin_row[None, :]
+        qn2 = q2 * cos_row[None, :] + q1 * sin_row[None, :]
+        tl.store(q_base + qh_offsets, qn1.to(q1_raw.dtype), mask=qh_mask)
+        tl.store(q_base + qh_offsets_2, qn2.to(q2_raw.dtype), mask=qh_mask)
+
+        kh_offsets = tl.arange(0, PAD_KH)[:, None] * head_dim + half_offsets[None, :]
+        kh_mask = (tl.arange(0, PAD_KH)[:, None] < k_num_heads) & half_mask[None, :]
+        kh_offsets_2 = kh_offsets + half
+        k1_raw = tl.load(k_base + kh_offsets, mask=kh_mask, other=0.0)
+        k2_raw = tl.load(k_base + kh_offsets_2, mask=kh_mask, other=0.0)
+        k1 = k1_raw.to(tl.float32)
+        k2 = k2_raw.to(tl.float32)
+        kn1 = k1 * cos_row[None, :] - k2 * sin_row[None, :]
+        kn2 = k2 * cos_row[None, :] + k1 * sin_row[None, :]
+        tl.store(k_base + kh_offsets, kn1.to(k1_raw.dtype), mask=kh_mask)
+        tl.store(k_base + kh_offsets_2, kn2.to(k2_raw.dtype), mask=kh_mask)
+
+
+def split_qkv_tp_rmsnorm_rope_impl(
+    input: torch.Tensor,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    q_hidden_size: int,
+    kv_hidden_size: int,
+    head_dim: int,
+    rotary_dim: int,
+    eps: float,
+    tp_world: int,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_tokens = input.shape[0]
+    input_2d = input.view(num_tokens, -1)
+    q = torch.empty(num_tokens, q_hidden_size, device=input.device, dtype=input.dtype)
+    k = torch.empty(num_tokens, kv_hidden_size, device=input.device, dtype=input.dtype)
+    v = torch.empty(num_tokens, kv_hidden_size, device=input.device, dtype=input.dtype)
+
+    if num_tokens == 0:
+        return q, k, v
+    num_vectorcore = get_vectorcore_num()
+    grid = (min(num_tokens, num_vectorcore),)
+    q_cols = q_hidden_size
+    k_cols = kv_hidden_size
+    q_num_heads = q_hidden_size // head_dim
+    k_num_heads = kv_hidden_size // head_dim
+
+    qk_var = torch.empty(num_tokens, 2, dtype=torch.float32, device=q.device)
+    _split_qkv_and_compute_local_qk_var_kernel[grid](
+        input_2d,
+        q,
+        k,
+        v,
+        qk_var,
+        num_tokens,
+        q_cols,
+        k_cols,
+        triton.next_power_of_2(q_cols),
+        triton.next_power_of_2(k_cols),
+    )
+    if tp_world > 1:
+        qk_var = tensor_model_parallel_all_reduce(qk_var)
+
+    cos_2d = cos.view(num_tokens, -1)
+    sin_2d = sin.view(num_tokens, -1)
+    q_2d = q.view(num_tokens, -1)
+    k_2d = k.view(num_tokens, -1)
+    _apply_global_rmsnorm_kernel[grid](
+        q_2d,
+        k_2d,
+        cos_2d,
+        sin_2d,
+        cos_2d.stride(0),
+        q_weight,
+        k_weight,
+        qk_var,
+        eps,
+        1.0 / tp_world,
+        num_tokens,
+        q_cols,
+        k_cols,
+        q_num_heads,
+        k_num_heads,
+        head_dim,
+        rotary_dim,
+        triton.next_power_of_2(q_cols),
+        triton.next_power_of_2(k_cols),
+        triton.next_power_of_2(q_num_heads),
+        triton.next_power_of_2(k_num_heads),
+        triton.next_power_of_2(rotary_dim // 2),
+    )
+
+    return q, k, v
+
+
+def split_qkv_tp_rmsnorm_rope_impl_fake(
+    input: torch.Tensor,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    q_hidden_size: int,
+    kv_hidden_size: int,
+    head_dim: int,
+    rotary_dim: int,
+    eps: float,
+    tp_world: int,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_tokens = input.shape[0]
+    q_out = torch.empty(
+        num_tokens,
+        q_hidden_size,
+        device=input.device,
+        dtype=input.dtype,
+    )
+    k_out = torch.empty(
+        num_tokens,
+        kv_hidden_size,
+        device=input.device,
+        dtype=input.dtype,
+    )
+    v_out = torch.empty(
+        num_tokens,
+        kv_hidden_size,
+        device=input.device,
+        dtype=input.dtype,
+    )
+    return q_out, k_out, v_out
+
+
+direct_register_custom_op(
+    op_name="split_qkv_tp_rmsnorm_rope",
+    op_func=split_qkv_tp_rmsnorm_rope_impl,
+    fake_impl=split_qkv_tp_rmsnorm_rope_impl_fake,
+    mutates_args=[],
+    dispatch_key="PrivateUse1",
+)
diff --git a/vllm_ascend/patch/worker/patch_minimax_m2.py b/vllm_ascend/patch/worker/patch_minimax_m2.py
index bff94bed..54418401 100644
--- a/vllm_ascend/patch/worker/patch_minimax_m2.py
+++ b/vllm_ascend/patch/worker/patch_minimax_m2.py
@@ -28,6 +28,8 @@ from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
 from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
 from vllm.platforms import current_platform
 
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice
+
 FP8_DTYPES = tuple(
     getattr(torch, dtype_name) for dtype_name in (
@@ -172,3 +174,31 @@ def _patched_load_weights(
 
 
 MiniMaxM2Model.load_weights = _patched_load_weights
+
+
+def _patch_forward(
+    self,
+    positions: torch.Tensor,
+    hidden_states: torch.Tensor,
+) -> torch.Tensor:
+    qkv, _ = self.qkv_proj(hidden_states)
+    cos, sin = get_cos_and_sin_slice()
+    q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
+        input=qkv,
+        q_weight=self.q_norm.weight,
+        k_weight=self.k_norm.weight,
+        q_hidden_size=self.q_size,
+        kv_hidden_size=self.kv_size,
+        head_dim=self.head_dim,
+        rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
+        eps=self.q_norm.variance_epsilon,
+        tp_world=self.q_norm.tp_world,
+        cos=cos,
+        sin=sin,
+    )
+    attn_output = self.attn(q, k, v)
+    output, _ = self.o_proj(attn_output)
+    return output
+
+
+MiniMaxM2Attention.forward = _patch_forward