xc-llm-ascend/vllm_ascend/ops/triton/batch_invariant/matmul.py
Zhijun Chen 0ead5e8681 perf: adaptive block size selection in linear_persistent kernel (#6537)
### What this PR does / why we need it?

**Optimization:** Replaces fixed block sizes (128x128x128) in
`linear_persistent_kernel` with adaptive selection logic that considers:
- Matrix dimensions (M, N, K) 
- Device NPU vector core count
- Data type (float32 vs others)

**Why:** Fixed block sizes leave the hardware under-utilized for many
matrix shapes. Adaptive sizing improves occupancy and memory efficiency
across varied workload shapes, which raises throughput for
batch-invariant linear operations in LLM inference.

**Details** (see the sketch after this list):
- Small matrices (M < 256): Size-proportional allocation
- Medium matrices (256 ≤ M < 1024): Balanced distribution based on grid capacity
- Large matrices (M ≥ 1024): Optimized for the dominant dimension
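
For illustration, the shape of this heuristic looks roughly like the sketch below. This is a simplified, illustrative version only (the exact thresholds, caps, and per-branch choices live in `linear_persistent` in this file); it reuses the `triton.cdiv` / `triton.next_power_of_2` helpers the kernel launcher already imports:

```python
from vllm.triton_utils import triton  # same helper import used by the kernel file


def pick_blocks_sketch(M: int, N: int, grid_size: int) -> tuple[int, int]:
    """Simplified, illustrative (BLOCK_M, BLOCK_N) choice; assumes M, N > 0.

    Not the verbatim heuristic from linear_persistent().
    """
    if M < 256:
        # Small M: block height follows M; width depends on how much of N the grid covers
        block_m = M
        block_n = min(256, triton.next_power_of_2(triton.cdiv(N, grid_size)))
    elif M < 1024:
        # Medium M: split the available programs between M-blocks and N-blocks
        block_m = 256
        n_programs = max(grid_size // triton.cdiv(M, block_m), 1)
        block_n = min(128, triton.next_power_of_2(triton.cdiv(N, n_programs)))
    else:
        # Large M: favour the dominant dimension
        block_m, block_n = (256, 128) if M > N else (128, 256)
    return block_m, block_n
```

The real selection additionally halves `BLOCK_K` for float32 inputs, special-cases empty matrices, and distinguishes `M > N` from `M <= N` within each regime.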

### Does this PR introduce _any_ user-facing change?

No. This is a performance optimization. The API and numerical results
remain unchanged; only kernel execution efficiency improves.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: DDCHY <843049740@qq.com>
Signed-off-by: zjchenn <zjchenn@gmail.com>
Co-authored-by: DDCHY <843049740@qq.com>
2026-02-04 21:36:26 +08:00

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from triton.runtime import driver # type: ignore
from vllm.triton_utils import tl, triton


@triton.jit
def matmul_bias_persistent_kernel(
    # Input tensor pointers
    x_ptr,
    y_ptr,
    bias_ptr,
    output_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Stride information
    stride_xm,
    stride_xk,  # Strides of x: [M, K]
    stride_yk,
    stride_yn,  # Strides of y: [K, N]
    stride_bias,  # Stride of bias: [N]
    stride_outm,
    stride_outn,  # Strides of output: [M, N]
    # Whether to use bias
    has_bias: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)  # Row block ID
    pid_n = tl.program_id(1)  # Column block ID
    # Calculate the starting position of the current block in the matrix
    rm_start = pid_m * BLOCK_M
    rn_start = pid_n * BLOCK_N
    # Create index ranges
    rm = rm_start + tl.arange(0, BLOCK_M)  # Row index range [BLOCK_M]
    rn = rn_start + tl.arange(0, BLOCK_N)  # Column index range [BLOCK_N]
    rk = tl.arange(0, BLOCK_K)  # K dimension index range [BLOCK_K]
    # Initialize accumulator to 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Loop over the K dimension, processing BLOCK_K elements per iteration
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_start = k * BLOCK_K
        # Calculate pointer offsets for x (row-major)
        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
        # Calculate pointer offsets for y (row-major)
        y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
        # Create masks to prevent out-of-bounds access
        x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
        y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N)
        # Load data chunks from global memory
        x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32)
        # Compute matrix multiplication accumulation
        acc += tl.dot(x_chunk, y_chunk, allow_tf32=False)
    # Add bias if the has_bias flag is set
    if has_bias:
        # Load bias values (broadcast to all rows)
        bias_ptrs = bias_ptr + rn * stride_bias
        bias_mask = rn < N
        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
        # Add bias to accumulator (automatic broadcasting)
        acc += bias_vals[None, :]
    # Calculate output pointer positions
    out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
    out_mask = (rm[:, None] < M) & (rn[None, :] < N)
    # Store result to global memory
    tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask)


def matmul_persistent(x, y, bias=None):
    """
    Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [K, N]
        bias: torch.Tensor, shape [N] or None
    Returns:
        output: torch.Tensor, shape [M, N]
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
    # Convert tensors to contiguous memory layout.
    # This prevents transposed tensors from causing incorrect stride() values,
    # which would lead to miscalculated data transfer volumes in subsequent operations.
    x = x.contiguous()
    y = y.contiguous()
    M, K = x.shape
    _, N = y.shape
    # Validate bias shape (if not None)
    if bias is not None:
        assert bias.dim() == 1, "bias must be a 1D tensor"
        assert y.shape[1] == bias.shape[0], (
            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
        )
    # Allocate output tensor (same data type as x)
    output = torch.empty((M, N), dtype=x.dtype, device=x.device)
    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    # Calculate grid size (one program per output block)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    # Handle case when bias is None
    if bias is None:
        # Create a dummy bias tensor (will not be used as has_bias=False)
        dummy_bias = torch.empty(0, dtype=x.dtype, device=x.device)
        has_bias = False
        bias_stride = 0
        bias_to_pass = dummy_bias
    else:
        has_bias = True
        bias_stride = bias.stride(0)
        bias_to_pass = bias
    # Launch kernel
    matmul_bias_persistent_kernel[grid](
        x,
        y,
        bias_to_pass,
        output,  # Input/Output tensors
        M,
        N,
        K,  # Matrix dimensions
        x.stride(0),
        x.stride(1),  # Strides of x
        y.stride(0),
        y.stride(1),  # Strides of y
        bias_stride,  # Stride of bias (0 if bias is None)
        output.stride(0),
        output.stride(1),  # Strides of output
        has_bias,  # Flag indicating whether to use bias
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return output
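

# Illustrative usage of matmul_persistent (a sketch, not part of the module API;
# it assumes an initialized Ascend NPU device with torch_npu available):
#     x = torch.randn(64, 512, device="npu", dtype=torch.float16)
#     w = torch.randn(512, 256, device="npu", dtype=torch.float16)
#     b = torch.randn(256, device="npu", dtype=torch.float16)
#     out = matmul_persistent(x, w, bias=b)  # same shape and semantics as x @ w + b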


@triton.jit
def linear_persistent_kernel(
    a_ptr,  # Pointer to tensor a, shape [M, K]
    b_ptr,  # Pointer to tensor b, shape [N, K]
    c_ptr,  # Pointer to output tensor c, shape [M, N]
    M,  # Number of rows in tensor a
    N,  # Number of rows in tensor b (number of columns in output c)
    K,  # Number of columns in both tensor a and tensor b
    stride_am,  # Stride of tensor a along dimension M (typically K)
    stride_ak,  # Stride of tensor a along dimension K (typically 1)
    stride_bn,  # Stride of tensor b along dimension N (typically K)
    stride_bk,  # Stride of tensor b along dimension K (typically 1)
    stride_cm,  # Stride of tensor c along dimension M (typically N)
    stride_cn,  # Stride of tensor c along dimension N (typically 1)
    BLOCK_M: tl.constexpr,  # Block size for M dimension
    BLOCK_N: tl.constexpr,  # Block size for N dimension
    BLOCK_K: tl.constexpr,  # Block size for K dimension
    NUM_BLOCKS_M: tl.constexpr,  # New: Number of blocks in M dimension
    NUM_BLOCKS_N: tl.constexpr,  # New: Number of blocks in N dimension
    GRID_SIZE: tl.constexpr,  # New: Fixed 1D grid size
):
    # Get current program's 1D index (1D grid)
    pid = tl.program_id(0)
    total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N  # Total number of output blocks
    # Loop over multiple blocks assigned to the current program
    for block_index in range(pid, total_blocks, GRID_SIZE):
        # Convert 1D block index to 2D coordinates (m_block, n_block)
        m_block = block_index // NUM_BLOCKS_N
        n_block = block_index % NUM_BLOCKS_N
        # Calculate starting indices of the current output block
        start_m = m_block * BLOCK_M
        start_n = n_block * BLOCK_N
        # Create row and column index ranges within the current block
        m_indices = start_m + tl.arange(0, BLOCK_M)
        n_indices = start_n + tl.arange(0, BLOCK_N)
        # Create masks to handle boundaries
        m_mask = m_indices < M
        n_mask = n_indices < N
        # Initialize accumulator to 0
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Loop over K dimension with step size BLOCK_K
        for k_offset in range(0, K, BLOCK_K):
            k_indices = k_offset + tl.arange(0, BLOCK_K)
            k_mask = k_indices < K
            # Load block of tensor a: shape [BLOCK_M, BLOCK_K]
            a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
            a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
            # Load block of tensor b: shape [BLOCK_N, BLOCK_K]
            b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
            b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
            # Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
            b_vals_transposed = tl.trans(b_vals)
            # Compute matrix multiplication: a_vals × b_vals_transposed
            product = tl.dot(a_vals, b_vals_transposed)
            acc += product
        # Store result to output tensor c
        c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
        tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])


def linear_persistent(x, y):
    """
    Implement matrix multiplication with Triton: x @ y^T
    Uses a fixed-size 1D grid
    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [N, K]
    Returns:
        output: torch.Tensor, shape [M, N]
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
    M, K = x.shape
    N, _ = y.shape
    # Allocate output tensor (same data type as x)
    output = torch.zeros((M, N), dtype=x.dtype, device=x.device)
    grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
    # Define block sizes (can be adjusted based on hardware)
    BLOCK_K = 256
    if x.dtype == torch.float32:
        BLOCK_K = BLOCK_K // 2
    grid_size_div4 = grid_size // 4
    if M == 0 or N == 0:
        BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 256
    elif M < 256:
        BLOCK_M = M
        if grid_size * 128 <= N:
            if M <= 128:
                BLOCK_N = 256
            else:
                BLOCK_N = 128
        elif grid_size * 32 >= N:
            if M > N:
                BLOCK_M = triton.cdiv(M, grid_size_div4)
                BLOCK_N = triton.cdiv(N, 4)
            else:
                BLOCK_M = triton.cdiv(M, 4)
                BLOCK_N = triton.cdiv(N, grid_size_div4)
        else:
            BLOCK_N = triton.next_power_of_2(triton.cdiv(N, grid_size))
    elif 256 <= M < 1024:
        if M < N:
            BLOCK_M = 256
            nums_m = triton.cdiv(M, BLOCK_M)
            nums_n = grid_size // nums_m
            if 128 * nums_n <= N:
                BLOCK_N = 128
            else:
                BLOCK_N = min(triton.next_power_of_2(triton.cdiv(N, nums_n)), 128)
        else:
            BLOCK_M = min(triton.cdiv(M, grid_size_div4), 256)
            BLOCK_N = min(triton.cdiv(N, 4), 128)
    else:
        if M > N:
            BLOCK_M, BLOCK_N = 256, 128
            nums_m = triton.cdiv(M, BLOCK_M)
            nums_n = triton.cdiv(N, BLOCK_N)
            if nums_m * nums_n < grid_size:
                BLOCK_M = triton.cdiv(M, grid_size_div4)
                BLOCK_N = triton.cdiv(N, 4)
        else:
            BLOCK_M, BLOCK_N = 128, 256
    # Calculate number of blocks per dimension (ceil division)
    num_blocks_m = triton.cdiv(M, BLOCK_M)
    num_blocks_n = triton.cdiv(N, BLOCK_N)
    # Set fixed 1D grid size
    grid = (grid_size,)
    # Launch kernel
    linear_persistent_kernel[grid](
        a_ptr=x,
        b_ptr=y,
        c_ptr=output,
        M=M,
        N=N,
        K=K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_bn=y.stride(0),
        stride_bk=y.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_BLOCKS_M=num_blocks_m,  # Number of blocks in M dimension
        NUM_BLOCKS_N=num_blocks_n,  # Number of blocks in N dimension
        GRID_SIZE=grid_size,  # Fixed grid size
    )
    return output
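

# Illustrative usage of linear_persistent (a sketch; assumes an Ascend NPU device):
#     x = torch.randn(32, 1024, device="npu", dtype=torch.float16)    # activations, shape [M, K]
#     w = torch.randn(4096, 1024, device="npu", dtype=torch.float16)  # weight, shape [N, K]
#     out = linear_persistent(x, w)  # computes x @ w.T, shape [32, 4096]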


def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
    # Process each batch separately with our persistent kernel
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)
        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul can handle various dimensions
    # For 2D x 2D, it's the same as mm
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 3 and b.ndim == 3:
        # Handle batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 2:
        # Handle 3D x 2D: common for linear layers
        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
        # Reshape to 2D, do mm, reshape back
        batch, seq, hidden = a.shape
        a_2d = a.reshape(-1, hidden)
        result_2d = matmul_persistent(a_2d, b)
        result = result_2d.reshape(batch, seq, -1)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 2 and b.ndim == 3:
        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
        # By broadcasting `a` to 3D, we can reuse the batched matrix
        # multiplication logic.
        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(a_expanded, b, out=out)
    elif a.ndim == 4 and b.ndim == 4:
        # Handle 4D attention tensors: [batch, heads, seq, dim]
        # Reshape to 3D, process, reshape back
        batch, heads, seq_a, dim_a = a.shape
        _, _, dim_b, seq_b = b.shape
        # Reshape to [batch*heads, seq_a, dim_a]
        a_3d = a.reshape(batch * heads, seq_a, dim_a)
        b_3d = b.reshape(batch * heads, dim_b, seq_b)
        # Do batched matmul
        result_3d = bmm_batch_invariant(a_3d, b_3d)
        # Reshape back to [batch, heads, seq_a, seq_b]
        result = result_3d.reshape(batch, heads, seq_a, seq_b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}"
        )


def linear_batch_invariant(input_, weight, bias=None):
    output = linear_persistent(input_, weight)
    if bias is not None:
        output = output + bias
    return output
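

# Illustrative sanity check (a sketch, not a shipped test): on an Ascend NPU device,
# linear_batch_invariant should match torch.nn.functional.linear within kernel rounding.
#     import torch.nn.functional as F
#     x = torch.randn(8, 1024, device="npu", dtype=torch.float16)
#     w = torch.randn(256, 1024, device="npu", dtype=torch.float16)
#     b = torch.randn(256, device="npu", dtype=torch.float16)
#     torch.testing.assert_close(linear_batch_invariant(x, w, b), F.linear(x, w, b),
#                                rtol=2e-2, atol=2e-2)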