2026-01-07 09:11:26 +08:00
|
|
|
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
|
|
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
|
#
|
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
|
#
|
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
#
|
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from triton.runtime import driver # type: ignore
|
|
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tiled matmul kernel: each program computes one [BLOCK_M, BLOCK_N] tile of
# output = x @ y (+ bias). Tiles are addressed by a 2D launch grid
# (axis 0 -> row tiles, axis 1 -> column tiles). Inputs are promoted to
# float32 before the dot product and the float32 accumulator is converted to
# the output element type on store.
@triton.jit
def matmul_bias_persistent_kernel(
    # Input tensor pointers
    x_ptr,
    y_ptr,
    bias_ptr,
    output_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Stride information
    stride_xm,
    stride_xk,  # Strides of x: [M, K]
    stride_yk,
    stride_yn,  # Strides of y: [K, N]
    stride_bias,  # Stride of bias: [N]
    stride_outm,
    stride_outn,  # Strides of output: [M, N]
    # Whether to use bias (compile-time flag; bias_ptr is not read when False)
    has_bias: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Tile coordinates of this program in the 2D launch grid
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    # Global row / column indices covered by this tile, plus the K lane offsets
    rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)  # [BLOCK_M]
    cols = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)  # [BLOCK_N]
    k_lane = tl.arange(0, BLOCK_K)  # [BLOCK_K]

    # Boundary masks that do not depend on the K step
    row_in_bounds = rows[:, None] < M
    col_in_bounds = cols[None, :] < N

    # float32 accumulator for the whole tile
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Walk the K dimension in BLOCK_K-sized steps, always in ascending order
    for k_step in range(0, tl.cdiv(K, BLOCK_K)):
        k_idx = k_step * BLOCK_K + k_lane

        # Addresses of the [BLOCK_M, BLOCK_K] slice of x and [BLOCK_K, BLOCK_N] slice of y
        x_ptrs = x_ptr + rows[:, None] * stride_xm + k_idx[None, :] * stride_xk
        y_ptrs = y_ptr + k_idx[:, None] * stride_yk + cols[None, :] * stride_yn

        # Out-of-range elements are read as 0 so they do not contribute to the sum
        x_tile = tl.load(x_ptrs, mask=row_in_bounds & (k_idx[None, :] < K), other=0.0).to(tl.float32)
        y_tile = tl.load(y_ptrs, mask=(k_idx[:, None] < K) & col_in_bounds, other=0.0).to(tl.float32)

        acc += tl.dot(x_tile, y_tile, allow_tf32=False)

    if has_bias:
        # One bias value per output column, broadcast over all rows of the tile
        bias_row = tl.load(bias_ptr + cols * stride_bias, mask=cols < N, other=0.0).to(tl.float32)
        acc += bias_row[None, :]

    # Write the tile back, converting to the output tensor's element type
    out_ptrs = output_ptr + rows[:, None] * stride_outm + cols[None, :] * stride_outn
    tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=row_in_bounds & col_in_bounds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def matmul_persistent(x, y, bias=None):
    """
    Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)

    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [K, N]
        bias: torch.Tensor, shape [N] or None

    Returns:
        output: torch.Tensor, shape [M, N], allocated with the dtype and device of x

    Raises:
        AssertionError: if x or y is not 2D, the inner dimensions differ,
            or bias is not a 1D tensor of length N.
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"

    M, K = x.shape
    _, N = y.shape

    # Validate bias shape (if not None)
    if bias is not None:
        assert bias.dim() == 1, "bias must be a 1D tensor"
        assert y.shape[1] == bias.shape[0], (
            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
        )

    # Allocate output tensor (same data type as x)
    output = torch.empty((M, N), dtype=x.dtype, device=x.device)

    # Nothing to compute for an empty output; this also avoids launching a
    # kernel with a zero-sized grid.
    if output.numel() == 0:
        return output

    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128

    # One kernel program per [BLOCK_M, BLOCK_N] output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    has_bias = bias is not None
    if has_bias:
        bias_to_pass = bias
        bias_stride = bias.stride(0)
    else:
        # The kernel signature still needs a tensor argument; it is never
        # read because has_bias=False compiles the bias branch out.
        bias_to_pass = torch.empty(0, dtype=x.dtype, device=x.device)
        bias_stride = 0

    # Launch kernel
    matmul_bias_persistent_kernel[grid](
        x,
        y,
        bias_to_pass,
        output,  # Input/Output tensors
        M,
        N,
        K,  # Matrix dimensions
        x.stride(0),
        x.stride(1),  # Strides of x
        y.stride(0),
        y.stride(1),  # Strides of y
        bias_stride,  # Stride of bias (0 if bias is None)
        output.stride(0),
        output.stride(1),  # Strides of output
        has_bias,  # Flag indicating whether to use bias
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Kernel for c = a @ b^T on a fixed-size 1D grid. Program `pid` processes
# output tiles pid, pid + GRID_SIZE, pid + 2 * GRID_SIZE, ... so every tile is
# handled by exactly one program and its K reduction always runs in ascending
# order.
@triton.jit
def linear_persistent_kernel(
    a_ptr,  # Pointer to tensor a, shape [M, K]
    b_ptr,  # Pointer to tensor b, shape [N, K]
    c_ptr,  # Pointer to output tensor c, shape [M, N]
    M,  # Number of rows in tensor a
    N,  # Number of rows in tensor b (number of columns in output c)
    K,  # Number of columns in both tensor a and tensor b
    stride_am,  # Stride of tensor a along dimension M (typically K)
    stride_ak,  # Stride of tensor a along dimension K (typically 1)
    stride_bn,  # Stride of tensor b along dimension N (typically K)
    stride_bk,  # Stride of tensor b along dimension K (typically 1)
    stride_cm,  # Stride of tensor c along dimension M (typically N)
    stride_cn,  # Stride of tensor c along dimension N (typically 1)
    BLOCK_M: tl.constexpr,  # Block size for M dimension
    BLOCK_N: tl.constexpr,  # Block size for N dimension
    BLOCK_K: tl.constexpr,  # Block size for K dimension
    NUM_BLOCKS_M: tl.constexpr,  # Number of blocks in M dimension
    NUM_BLOCKS_N: tl.constexpr,  # Number of blocks in N dimension
    GRID_SIZE: tl.constexpr,  # Fixed 1D grid size (number of launched programs)
):
    pid = tl.program_id(0)
    num_tiles = NUM_BLOCKS_M * NUM_BLOCKS_N

    # Strided walk over the flattened tile index space
    for tile_id in range(pid, num_tiles, GRID_SIZE):
        # Flattened index -> (row tile, column tile); column tiles vary fastest
        tile_m = tile_id // NUM_BLOCKS_N
        tile_n = tile_id % NUM_BLOCKS_N

        # Global row / column indices of this tile and their boundary masks
        rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
        cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
        row_ok = rows < M
        col_ok = cols < N

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for k_base in range(0, K, BLOCK_K):
            ks = k_base + tl.arange(0, BLOCK_K)
            k_ok = ks < K

            # [BLOCK_M, BLOCK_K] slice of a; masked-out elements read as 0
            a_tile = tl.load(
                a_ptr + rows[:, None] * stride_am + ks[None, :] * stride_ak,
                mask=row_ok[:, None] & k_ok[None, :],
                other=0.0,
            )

            # [BLOCK_N, BLOCK_K] slice of b; masked-out elements read as 0
            b_tile = tl.load(
                b_ptr + cols[:, None] * stride_bn + ks[None, :] * stride_bk,
                mask=col_ok[:, None] & k_ok[None, :],
                other=0.0,
            )

            # b is stored [N, K]; transpose the slice to [BLOCK_K, BLOCK_N] for the dot
            acc += tl.dot(a_tile, tl.trans(b_tile))

        # acc is float32 and is passed to tl.store without an explicit cast,
        # so the conversion to c's element type is left to the store itself.
        c_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
        tl.store(c_ptrs, acc, mask=row_ok[:, None] & col_ok[None, :])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def linear_persistent(x, y):
    """
    Implement matrix multiplication with Triton: x @ y^T
    Uses a fixed-size 1D grid

    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [N, K]

    Returns:
        output: torch.Tensor, shape [M, N], allocated with the dtype and device of x

    Raises:
        AssertionError: if x or y is not 2D, or their last dimensions differ.
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"

    M, K = x.shape
    N, _ = y.shape

    # Allocate output tensor (same data type as x)
    output = torch.zeros((M, N), dtype=x.dtype, device=x.device)

    # No output tiles exist for an empty result, so skip the device query
    # and the kernel launch entirely.
    if output.numel() == 0:
        return output

    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128

    # Calculate number of blocks per dimension (ceil division)
    num_blocks_m = triton.cdiv(M, BLOCK_M)
    num_blocks_n = triton.cdiv(N, BLOCK_N)

    # Fixed 1D grid size derived from the device, independent of the problem size.
    # NOTE(review): presumably half of the vector cores corresponds to the
    # number of usable compute cores on the NPU -- confirm.
    device_props = driver.active.utils.get_device_properties(torch.npu.current_device())
    # Clamp to 1: a grid size of 0 would launch no programs and would also be
    # an invalid (zero) step for the kernel's tile loop.
    grid_size = max(1, device_props["num_vectorcore"] // 2)
    grid = (grid_size,)

    # Launch kernel
    linear_persistent_kernel[grid](
        a_ptr=x,
        b_ptr=y,
        c_ptr=output,
        M=M,
        N=N,
        K=K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_bn=y.stride(0),
        stride_bk=y.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_BLOCKS_M=num_blocks_m,  # Number of blocks in M dimension
        NUM_BLOCKS_N=num_blocks_n,  # Number of blocks in N dimension
        GRID_SIZE=grid_size,  # Fixed grid size
    )

    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mm_batch_invariant(a, b):
    """Return ``a @ b`` for 2D tensors, computed by ``matmul_persistent`` without a bias."""
    product = matmul_persistent(a, b)
    return product
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bmm_batch_invariant(a, b, *, out=None):
    """
    Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)

    Each batch entry is multiplied separately with ``matmul_persistent`` and
    the per-batch results are stacked along dimension 0.

    Parameters:
        a: torch.Tensor, shape [B, M, K]
        b: torch.Tensor, shape [B, K, N]
        out: optional torch.Tensor that receives a copy of the result

    Returns:
        ``out`` when it is given, otherwise a new tensor of shape [B, M, N]

    Raises:
        ValueError: if either input is not 3D or the batch sizes differ.
    """
    if a.ndim != 3 or b.ndim != 3:
        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
    # Without this check a shorter `a` would silently drop trailing batches
    # of `b`, and a longer `a` would fail with an IndexError mid-loop.
    if a.shape[0] != b.shape[0]:
        raise ValueError(f"bmm_batch_invariant expects matching batch sizes, got shapes {a.shape} and {b.shape}")

    if a.shape[0] == 0:
        # torch.stack rejects an empty list, so build the empty result directly.
        assert a.shape[2] == b.shape[1], (
            f"Matrix dimension mismatch: x.shape[1]={a.shape[2]}, y.shape[0]={b.shape[1]}"
        )
        result = torch.empty((0, a.shape[1], b.shape[2]), dtype=a.dtype, device=a.device)
    else:
        # Process each batch separately with our persistent kernel
        result = torch.stack([matmul_persistent(a[i], b[i]) for i in range(a.shape[0])], dim=0)

    if out is not None:
        out.copy_(result)
        return out
    return result
|
2026-01-07 09:11:26 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def addmm_batch_invariant(bias, a, b):
    """Return ``a @ b + bias`` by forwarding to ``matmul_persistent`` with the given bias."""
    fused = matmul_persistent(a, b, bias=bias)
    return fused
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def matmul_batch_invariant(a, b, *, out=None):
    """
    Matrix multiply for a fixed set of rank combinations, routed through
    ``matmul_persistent`` / ``bmm_batch_invariant``.

    Supported combinations (a x b):
        2D x 2D: (M, K) @ (K, N) -> (M, N)
        3D x 3D: (B, M, K) @ (B, K, N) -> (B, M, N)
        3D x 2D: (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
        2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
        4D x 4D: (batch, heads, seq_a, dim) @ (batch, heads, dim, seq_b)
                 -> (batch, heads, seq_a, seq_b)

    Parameters:
        a: torch.Tensor, left operand
        b: torch.Tensor, right operand
        out: optional torch.Tensor that receives a copy of the result

    Returns:
        ``out`` when it is given, otherwise the newly computed tensor

    Raises:
        ValueError: for any other combination of ranks.
    """
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
    elif a.ndim == 3 and b.ndim == 3:
        # Handle batched case like bmm (which takes care of `out` itself)
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 2:
        # Handle 3D x 2D: common for linear layers
        # Reshape to 2D, do mm, reshape back. Explicit sizes are used instead
        # of -1 because -1 cannot be inferred when the tensor has zero elements.
        batch, seq, hidden = a.shape
        a_2d = a.reshape(batch * seq, hidden)
        result = matmul_persistent(a_2d, b).reshape(batch, seq, b.shape[1])
    elif a.ndim == 2 and b.ndim == 3:
        # Handle 2D x 3D by broadcasting `a` to 3D so the batched logic can be reused
        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(a_expanded, b, out=out)
    elif a.ndim == 4 and b.ndim == 4:
        # Handle 4D attention tensors: [batch, heads, seq, dim]
        # Fold batch and heads into one batch dimension, run bmm, unfold again
        batch, heads, seq_a, dim_a = a.shape
        _, _, dim_b, seq_b = b.shape
        a_3d = a.reshape(batch * heads, seq_a, dim_a)
        b_3d = b.reshape(batch * heads, dim_b, seq_b)
        result = bmm_batch_invariant(a_3d, b_3d).reshape(batch, heads, seq_a, seq_b)
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}"
        )

    if out is not None:
        out.copy_(result)
        return out
    return result
|
2026-01-07 09:11:26 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def linear_batch_invariant(input_, weight, bias=None):
    """
    Linear layer ``input_ @ weight^T + bias`` built on ``linear_persistent``.

    Parameters:
        input_: torch.Tensor, shape [..., K]; inputs with three or more
            dimensions are flattened to [rows, K] for the kernel and the
            leading dimensions are restored afterwards
        weight: torch.Tensor, shape [N, K]
        bias: optional torch.Tensor added to the result with regular
            broadcasting

    Returns:
        output: torch.Tensor, shape [..., N]
    """
    if input_.dim() < 3:
        # Already in the layout linear_persistent validates and consumes
        output = linear_persistent(input_, weight)
    else:
        # linear_persistent only accepts 2D inputs, so fold every leading
        # dimension into the row dimension and unfold it afterwards.
        leading_shape = input_.shape[:-1]
        num_rows = 1
        for size in leading_shape:
            num_rows *= size
        flat_input = input_.reshape(num_rows, input_.shape[-1])
        output = linear_persistent(flat_input, weight).reshape(*leading_shape, weight.shape[0])

    if bias is not None:
        output = output + bias
    return output
|