# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import math
from typing import Optional, Union

import torch

from vllm.triton_utils import tl, triton


@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.

    Each program computes exactly one output element (m, k), so the launch
    grid must contain M * K programs. The reduction walks N in fixed
    BLOCK_SIZE chunks inside a single program, so the accumulation order for
    one output element does not depend on M or K.
    """
    # Program ID selects the output element this program computes.
    pid = tl.program_id(0)

    m_idx = pid // K
    k_idx = pid % K

    # Defensive bounds check in case the grid is larger than M * K.
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate the sum across the reduction dimension in float32.
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        input_idx = (m_idx * input_stride0 + n_offsets * input_stride1 +
                     k_idx * input_stride2)

        # Masked-out lanes load 0.0 so they do not affect the sum.
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        # Up-cast before reducing: tl.sum does not promote fp16/bf16 inputs
        # by default, which would round every per-block partial sum to
        # half precision.
        acc += tl.sum(vals.to(tl.float32))

    # N == 0 yields 0.0 / 0 = nan, matching torch.mean on an empty dim.
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    # tl.store casts the float32 mean to the output element type.
    tl.store(output_ptr + output_idx, mean_val)


def mean_dim(
    input_: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input_: Input tensor.
        dim: Single dimension along which to compute the mean (negative
            values count from the end).
        keepdim: Whether to keep the reduced dimension with size 1.
        dtype: Output dtype. The input is cast to this dtype before the
            reduction. If None, uses the input dtype (or float32 for integer
            inputs).

    Returns:
        Tensor with mean values along the specified dimension.

    Raises:
        AssertionError: If ``dim`` is out of range for ``input_``.
    """
    assert -input_.ndim <= dim < input_.ndim, (
        f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")

    if dim < 0:
        dim += input_.ndim

    # Integer means are computed (and returned) in float32 by default.
    if dtype is None:
        if input_.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32,
                            torch.int64):
            dtype = torch.float32
        else:
            dtype = input_.dtype

    if input_.dtype != dtype:
        input_ = input_.to(dtype)

    shape = list(input_.shape)

    # Collapse the tensor to (M, N, K) around the reduced dimension.
    M = math.prod(shape[:dim])
    N = shape[dim]
    K = math.prod(shape[dim + 1:])

    # reshape returns a view when possible and copies otherwise; the kernel
    # reads through the resulting strides either way.
    input_3d = input_.reshape(M, N, K)

    if keepdim:
        output_shape = shape[:dim] + [1] + shape[dim + 1:]
    else:
        output_shape = shape[:dim] + shape[dim + 1:]

    output = torch.empty(output_shape, dtype=dtype, device=input_.device)

    # Nothing to compute, and no point launching a zero-sized grid.
    if output.numel() == 0:
        return output

    # `output` is freshly allocated, contiguous and holds M * K elements, so
    # an (M, K) view aliases it for both keepdim settings.
    output_2d = output.view(M, K)

    BLOCK_SIZE = 1024
    grid = (M * K, )
    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1),
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output


def mean_batch_invariant(
    input_: torch.Tensor,
    dim: Union[int, list[int], None],
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Batch-invariant replacement for ``torch.mean(input_, dim, keepdim, dtype=dtype)``.

    A single reduction dimension is computed by the Triton ``mean_dim``
    kernel. Multiple (or no) dimensions fall back to a float32 ``torch.sum``
    divided by the number of reduced elements.

    Args:
        input_: Input tensor. The multi-dim path requires float16, bfloat16
            or float32 input.
        dim: Dimensions to reduce. A bare int is treated as ``[dim]``; an
            empty list or None reduces over all dimensions.
        keepdim: Whether to keep the reduced dimensions with size 1.
        dtype: Either None (result keeps the input dtype; float32 for integer
            inputs on the single-dim path) or ``torch.float32``.

    Returns:
        The mean tensor.

    Raises:
        AssertionError: If ``dtype`` is unsupported, or if the multi-dim path
            receives a non-float input.
    """
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"

    if isinstance(dim, int):
        dim = [dim]

    if dim is not None and len(dim) == 1:
        return mean_dim(input_, dim[0], keepdim=keepdim, dtype=dtype)

    assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
                            }, ("only float types supported for now")

    # An empty / missing dim list means "reduce everything", as in torch.mean.
    if not dim:
        dim = list(range(input_.ndim))

    n_elems = math.prod(input_.shape[d] for d in dim)

    # NOTE(review): this path relies on torch.sum; whether it is batch
    # invariant on the target device is not established here — confirm.
    total = torch.sum(input_, dim=dim, keepdim=keepdim, dtype=torch.float32)
    # Divide while still in float32, then cast: casting the raw sum to a
    # half-precision dtype first overflows to inf once it exceeds the
    # dtype's range (65504 for float16).
    return (total / n_elems).to(dtype or input_.dtype)