[Feature] implement basic framework for batch invariant (#5517)
### What this PR does / why we need it?
This PR implements the basic framework for batch invariance. For details, see
https://github.com/vllm-project/vllm-ascend/issues/5487.
### Does this PR introduce _any_ user-facing change?
We reuse the `vllm_is_batch_invariant` function from vLLM to determine whether
batch-invariant mode is enabled.
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Lord_of_Ironhill <suiweiyi@huawei.com>
Signed-off-by: zjchenn <zjchenn@gmail.com>
Signed-off-by: wangx700 <wangxin700@huawei.com>
Co-authored-by: Lord_of_Ironhill <suiweiyi@huawei.com>
Co-authored-by: zjchenn <zjchenn@gmail.com>
Co-authored-by: wangx700 <wangxin700@huawei.com>
This commit is contained in:
177
vllm_ascend/ops/triton/batch_invariant/mean.py
Normal file
177
vllm_ascend/ops/triton/batch_invariant/mean.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.

    Each program instance produces exactly one output element, identified by
    the pair (m_idx, k_idx) decoded from the 1-D program id, so the launch
    grid is expected to contain M * K programs.

    Args:
        input_ptr: Base pointer of the input, addressed through the three
            input strides (one per axis of the (M, N, K) view).
        output_ptr: Base pointer of the output, addressed through the two
            output strides (one per axis of the (M, K) view).
        input_stride0, input_stride1, input_stride2: Element strides of the
            input along the M, N and K axes respectively.
        output_stride0, output_stride1: Element strides of the output along
            the M and K axes respectively.
        M, N, K: Extents of the (M, N, K) view of the input.
        BLOCK_SIZE: Compile-time number of elements of the reduction axis
            loaded per loop iteration.
    """
    # Program ID gives us which output element we're computing
    pid = tl.program_id(0)

    # Compute output indices: pid enumerates the (M, K) output with K varying
    # fastest
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check: programs that fall outside the (M, K) output do nothing
    if m_idx >= M or k_idx >= K:
        return
    # Accumulate sum across reduction dimension, one BLOCK_SIZE chunk at a
    # time, into a single scalar accumulator
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        # The last chunk may run past N; masked-off lanes are not loaded
        mask = n_offsets < N

        # Calculate input indices of element (m_idx, n, k_idx) for each n in
        # the chunk
        input_idx = m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
        # Load and accumulate; masked lanes read as 0.0 so they do not
        # contribute to the sum
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store the single result for (m_idx, k_idx)
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)
|
||||
|
||||
|
||||
def mean_dim(
|
||||
input_: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Triton implementation of torch.mean with single dimension reduction.
|
||||
|
||||
Args:
|
||||
input: Input tensor
|
||||
dim: Single dimension along which to compute mean
|
||||
keepdim: Whether to keep the reduced dimension
|
||||
dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)
|
||||
|
||||
Returns:
|
||||
Tensor with mean values along specified dimension
|
||||
"""
|
||||
# Validate inputs
|
||||
assert -input_.ndim <= dim < input_.ndim, (
|
||||
f"Invalid dimension {dim} for tensor with {input_.ndim} dimensions")
|
||||
|
||||
# Handle negative dim
|
||||
if dim < 0:
|
||||
dim = dim + input_.ndim
|
||||
# Handle dtype
|
||||
if dtype is None:
|
||||
if input_.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = input_.dtype
|
||||
# Convert input to appropriate dtype if needed
|
||||
if input_.dtype != dtype:
|
||||
input_ = input_.to(dtype)
|
||||
|
||||
# Get input shape and strides
|
||||
shape = list(input_.shape)
|
||||
# Calculate dimensions for kernel
|
||||
M = 1
|
||||
for i in range(dim):
|
||||
M *= shape[i]
|
||||
|
||||
N = shape[dim]
|
||||
|
||||
K = 1
|
||||
for i in range(dim + 1, len(shape)):
|
||||
K *= shape[i]
|
||||
|
||||
# Reshape input to 3D view (M, N, K)
|
||||
input_3d = input_.reshape(M, N, K)
|
||||
|
||||
# Create output shape
|
||||
if keepdim:
|
||||
output_shape = shape.copy()
|
||||
output_shape[dim] = 1
|
||||
else:
|
||||
output_shape = shape[:dim] + shape[dim + 1:]
|
||||
|
||||
# Create output tensor
|
||||
output = torch.empty(output_shape, dtype=dtype, device=input_.device)
|
||||
|
||||
# Reshape output for kernel
|
||||
if keepdim:
|
||||
output_2d = output.reshape(M, 1, K).squeeze(1)
|
||||
else:
|
||||
output_2d = output.reshape(M, K)
|
||||
|
||||
# Launch kernel
|
||||
grid = (M * K, )
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
mean_kernel[grid](
|
||||
input_3d,
|
||||
output_2d,
|
||||
input_3d.stride(0),
|
||||
input_3d.stride(1),
|
||||
input_3d.stride(2),
|
||||
output_2d.stride(0),
|
||||
output_2d.stride(1) if output_2d.ndim > 1 else 0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def mean_batch_invariant(
|
||||
input_: torch.Tensor,
|
||||
dim: list[int],
|
||||
keepdim: bool = False,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
):
|
||||
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
||||
if len(dim) == 1:
|
||||
return mean_dim(input_, dim[0], keepdim=keepdim)
|
||||
else:
|
||||
assert input_.dtype in {torch.float16, torch.bfloat16, torch.float32
|
||||
}, ("only float types supported for now")
|
||||
if len(dim) == 0:
|
||||
dim = list(range(input_.ndim))
|
||||
n_elems = 1
|
||||
for d in dim:
|
||||
n_elems *= input_.shape[d]
|
||||
return torch.sum(input_, dim=dim, keepdim=keepdim,
|
||||
dtype=torch.float32).to(dtype
|
||||
or input_.dtype) / n_elems
|
||||
Reference in New Issue
Block a user