### What this PR does / why we need it?
https://github.com/vllm-project/vllm-ascend/pull/5517 enabled batch-invariant (BI)
mode in vllm-ascend, but we observed that BI performance in eager mode remains
suboptimal.
This PR further integrates batch-invariant mode with torch.compile, which
improves output-token throughput by about 3.5× (429.16 vs. 120.85 tok/s in eager
mode) when tested with Qwen3-0.6B.
### Does this PR introduce _any_ user-facing change?
Previously, enabling both aclgraph and batch-invariant mode caused a "ub overflow"
error. This happened because transposed (non-contiguous) input tensors reached the
Triton kernels with stride() values that led to miscalculated data-transfer sizes.
To fix this, we now call .contiguous() on the input tensors before passing them to
the Triton kernels. This guarantees a contiguous memory layout and prevents
transposed tensors from causing incorrect stride calculations.
### Test Plan
```
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_batch_invariant.py
```
### Test Result
```
============================================================================ slowest durations ============================================================================
87.37s call tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle
77.39s call tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
74.04s call tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_logprobs_without_batch_invariance_should_fail
73.59s call tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_simple_generation
(8 durations < 0.005s hidden. Use -vv to show these durations.)
================================================================ 4 passed, 3 warnings in 312.45s (0:05:12) ================================================================
```
### Performance
```bash
export VLLM_BATCH_INVARIANT=1
vllm serve /home/Qwen3-0.6B \
  --served-model-name qwen \
  --port 8000 \
  --max-num-seqs 256 \
  --tensor-parallel-size 1 \
  --max-model-len 5500 \
  --max-num-batched-tokens 5500 \
  --reasoning-parser qwen3 \
  --gpu-memory-utilization 0.9 \
  --compilation_config '{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes":[1,2,4,8,16,32]}' \
  --additional-config '{"ascend_scheduler_config":{"enabled":true},"enable_weight_nz_layout":true}'

vllm bench serve --served-model-name qwen --trust-remote-code --backend vllm \
  --model /home/Qwen3-0.6B/ --endpoint /v1/completions --dataset-name random \
  --random-input-len 512 --random-output-len 256 --num-prompts 800 \
  --max-concurrency 8
```
torch.compile batch invariant performance:
```
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 8
Benchmark duration (s): 477.21
Total input tokens: 409600
Total generated tokens: 204800
Request throughput (req/s): 1.68
Output token throughput (tok/s): 429.16
Peak output token throughput (tok/s): 472.00
Peak concurrent requests: 16.00
Total token throughput (tok/s): 1287.48
---------------Time to First Token----------------
Mean TTFT (ms): 285.53
Median TTFT (ms): 312.70
P99 TTFT (ms): 324.22
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 17.59
Median TPOT (ms): 17.50
P99 TPOT (ms): 18.44
---------------Inter-token Latency----------------
Mean ITL (ms): 17.59
Median ITL (ms): 17.45
P99 ITL (ms): 18.76
==================================================
```
Eager-mode batch invariant performance:
```
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 8
Benchmark duration (s): 1694.70
Total input tokens: 409600
Total generated tokens: 204800
Request throughput (req/s): 0.47
Output token throughput (tok/s): 120.85
Peak output token throughput (tok/s): 136.00
Peak concurrent requests: 16.00
Total token throughput (tok/s): 362.54
---------------Time to First Token----------------
Mean TTFT (ms): 164.29
Median TTFT (ms): 129.71
P99 TTFT (ms): 1961.66
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 65.81
Median TPOT (ms): 65.15
P99 TPOT (ms): 72.27
---------------Inter-token Latency----------------
Mean ITL (ms): 65.81
Median ITL (ms): 64.64
P99 ITL (ms): 75.72
==================================================
```
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: huangning1995 <huangning12@huawei.com>
397 lines
14 KiB
Python
397 lines
14 KiB
Python
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
|
||
# SPDX-License-Identifier: Apache-2.0
|
||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
# This file is a part of the vllm-ascend project.
|
||
#
|
||
|
||
import torch
|
||
from triton.runtime import driver # type: ignore
|
||
from vllm.triton_utils import tl, triton
|
||
|
||
|
||
# Triton kernel computing output = x @ y (+ bias) for 2D operands.
#
# Launch layout: a 2D grid where program (pid_m, pid_n) computes exactly one
# [BLOCK_M, BLOCK_N] output tile. Despite the "persistent" in its name, there is
# no loop over tiles; each program handles a single tile.
#
# Numerics: x/y chunks are upcast to float32, accumulated in a float32 tile in a
# fixed K-tile order (k = 0 .. cdiv(K, BLOCK_K) - 1), and only cast to the output
# element type at the final store. Because block sizes and the reduction order are
# fixed, a given output element is computed the same way regardless of M. This is
# presumably what makes the kernel batch-invariant.
#
# All addressing goes through the explicit strides, so any stride layout is
# addressable. The host wrapper in this file makes x and y contiguous before
# launch.
@triton.jit
def matmul_bias_persistent_kernel(
    # Input tensor pointers
    x_ptr,
    y_ptr,
    bias_ptr,  # Only dereferenced when has_bias is True
    output_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Stride information
    stride_xm,
    stride_xk,  # Strides of x: [M, K]
    stride_yk,
    stride_yn,  # Strides of y: [K, N]
    stride_bias,  # Stride of bias: [N]
    stride_outm,
    stride_outn,  # Strides of output: [M, N]
    # Whether to use bias (compile-time flag: the bias branch is compiled out when False)
    has_bias: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)  # Row block ID
    pid_n = tl.program_id(1)  # Column block ID

    # Calculate the starting position of the current block in the matrix
    rm_start = pid_m * BLOCK_M
    rn_start = pid_n * BLOCK_N

    # Create index ranges
    rm = rm_start + tl.arange(0, BLOCK_M)  # Row index range [BLOCK_M]
    rn = rn_start + tl.arange(0, BLOCK_N)  # Column index range [BLOCK_N]
    rk = tl.arange(0, BLOCK_K)  # K dimension index range [BLOCK_K]

    # Initialize the float32 accumulator to 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over the K dimension, processing BLOCK_K elements per iteration
    # (fixed iteration order -> fixed floating-point summation order)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_start = k * BLOCK_K
        # Pointer offsets for the [BLOCK_M, BLOCK_K] tile of x (strided addressing)
        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
        # Pointer offsets for the [BLOCK_K, BLOCK_N] tile of y (strided addressing)
        y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn

        # Masks prevent out-of-bounds access on the ragged last tiles
        x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
        y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N)

        # Load chunks (masked lanes read as 0) and upcast to float32
        x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32)

        # Accumulate the tile product; TF32 is explicitly disabled
        acc += tl.dot(x_chunk, y_chunk, allow_tf32=False)

    # Add bias if the has_bias flag is set
    if has_bias:
        # Load bias values (broadcast to all rows)
        bias_ptrs = bias_ptr + rn * stride_bias
        bias_mask = rn < N
        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
        # Add bias to accumulator (automatic broadcasting)
        acc += bias_vals[None, :]

    # Calculate output pointer positions
    out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
    out_mask = (rm[:, None] < M) & (rn[None, :] < N)

    # Cast the float32 accumulator to the output element type and store
    tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask)
|
||
|
||
|
||
def matmul_persistent(x, y, bias=None):
    """Compute ``x @ y`` (plus ``bias`` when given) with a Triton kernel.

    The work is split into fixed 128x128 output tiles with a fixed K-step of 64,
    launched on a 2D grid with one program per tile.

    Args:
        x: 2D tensor of shape [M, K].
        y: 2D tensor of shape [K, N].
        bias: optional 1D tensor of shape [N], added to every output row.

    Returns:
        A newly allocated [M, N] tensor with the dtype and device of ``x``.

    Raises:
        AssertionError: if the operands are not 2D, their inner dimensions differ,
            or ``bias`` is not a 1D tensor of length N.
    """
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"

    # Hand Triton dense row-major tensors. Transposed views would otherwise reach
    # the kernel with swapped strides, which led to miscalculated data-transfer
    # volumes ("ub overflow") when combined with aclgraph.
    x, y = x.contiguous(), y.contiguous()

    M, K = x.shape
    N = y.shape[1]

    if bias is not None:
        assert bias.dim() == 1, "bias must be a 1D tensor"
        assert y.shape[1] == bias.shape[0], (
            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
        )

    output = torch.empty((M, N), dtype=x.dtype, device=x.device)

    # Fixed tile sizes: keeping them independent of M keeps the per-element
    # reduction order independent of the batch size.
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    if bias is None:
        # Empty placeholder: the kernel never reads it because has_bias is False.
        has_bias, bias_stride, bias_to_pass = False, 0, torch.empty(0, dtype=x.dtype, device=x.device)
    else:
        has_bias, bias_stride, bias_to_pass = True, bias.stride(0), bias

    matmul_bias_persistent_kernel[grid](
        x,
        y,
        bias_to_pass,
        output,
        M,
        N,
        K,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        bias_stride,
        output.stride(0),
        output.stride(1),
        has_bias,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return output
|
||
|
||
|
||
# Triton kernel computing c = a @ b^T for a [M, K] and b [N, K].
#
# Launch layout: a fixed 1D grid of GRID_SIZE programs. The output is divided
# into NUM_BLOCKS_M * NUM_BLOCKS_N tiles of [BLOCK_M, BLOCK_N], and program `pid`
# processes tiles pid, pid + GRID_SIZE, pid + 2*GRID_SIZE, ... so the grid size
# does not depend on the problem size.
#
# Each tile is reduced over K in a fixed order (k_offset = 0, BLOCK_K, ...) into a
# float32 accumulator.
#
# NOTE(review): NUM_BLOCKS_M / NUM_BLOCKS_N / GRID_SIZE are tl.constexpr, so Triton
# specializes (compiles) a separate kernel per distinct value. With a varying M
# this can trigger recompilations; confirm that this is intended.
@triton.jit
def linear_persistent_kernel(
    a_ptr,  # Pointer to tensor a, shape [M, K]
    b_ptr,  # Pointer to tensor b, shape [N, K]
    c_ptr,  # Pointer to output tensor c, shape [M, N]
    M,  # Number of rows in tensor a
    N,  # Number of rows in tensor b (number of columns in output c)
    K,  # Number of columns in both tensor a and tensor b
    stride_am,  # Stride of tensor a along dimension M (typically K)
    stride_ak,  # Stride of tensor a along dimension K (typically 1)
    stride_bn,  # Stride of tensor b along dimension N (typically K)
    stride_bk,  # Stride of tensor b along dimension K (typically 1)
    stride_cm,  # Stride of tensor c along dimension M (typically N)
    stride_cn,  # Stride of tensor c along dimension N (typically 1)
    BLOCK_M: tl.constexpr,  # Block size for M dimension
    BLOCK_N: tl.constexpr,  # Block size for N dimension
    BLOCK_K: tl.constexpr,  # Block size for K dimension
    NUM_BLOCKS_M: tl.constexpr,  # Number of tiles along M (cdiv(M, BLOCK_M))
    NUM_BLOCKS_N: tl.constexpr,  # Number of tiles along N (cdiv(N, BLOCK_N))
    GRID_SIZE: tl.constexpr,  # Number of launched programs; also the tile-loop step (must be >= 1)
):
    # Get current program's 1D index (1D grid)
    pid = tl.program_id(0)
    total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N  # Total number of output tiles

    # Loop over the tiles assigned to this program (strided by the grid size)
    for block_index in range(pid, total_blocks, GRID_SIZE):
        # Convert 1D tile index to 2D tile coordinates (row-major over tiles)
        m_block = block_index // NUM_BLOCKS_N
        n_block = block_index % NUM_BLOCKS_N

        # Calculate starting indices of the current output tile
        start_m = m_block * BLOCK_M
        start_n = n_block * BLOCK_N

        # Create row and column index ranges within the current tile
        m_indices = start_m + tl.arange(0, BLOCK_M)
        n_indices = start_n + tl.arange(0, BLOCK_N)

        # Create masks to handle ragged edges
        m_mask = m_indices < M
        n_mask = n_indices < N

        # Initialize the float32 accumulator to 0 (fresh for every tile)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        # Loop over K dimension with step size BLOCK_K
        for k_offset in range(0, K, BLOCK_K):
            k_indices = k_offset + tl.arange(0, BLOCK_K)
            k_mask = k_indices < K

            # Load block of tensor a: shape [BLOCK_M, BLOCK_K] (masked lanes read as 0)
            a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
            a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)

            # Load block of tensor b: shape [BLOCK_N, BLOCK_K]
            b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
            b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)

            # Explicitly transpose b tile using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
            b_vals_transposed = tl.trans(b_vals)

            # Compute tile product a_vals x b_vals_transposed in the loaded dtype.
            # NOTE(review): unlike matmul_bias_persistent_kernel, operands are not
            # upcast to float32 and allow_tf32=False is not passed; confirm this
            # is intended for float32 inputs.
            product = tl.dot(a_vals, b_vals_transposed)
            acc += product

        # Store this tile's result to output tensor c.
        # NOTE(review): acc is float32 and is stored without an explicit .to(...);
        # this relies on tl.store converting to c's element type.
        c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
        tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])
|
||
|
||
|
||
def linear_persistent(x, y):
    """Compute ``x @ y^T`` with a persistent-style Triton kernel on a fixed 1D grid.

    The grid has ``num_vectorcore // 2`` programs on the current NPU (at least one).
    Each program loops over the output tiles assigned to it, so the launch shape
    does not depend on M.

    Args:
        x: 2D tensor of shape [M, K].
        y: 2D tensor of shape [N, K] (e.g. a linear-layer weight).

    Returns:
        A newly allocated [M, N] tensor with the dtype and device of ``x``.

    Raises:
        AssertionError: if the operands are not 2D or their K dimensions differ.
    """
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"

    # Same guard as matmul_persistent. Passing transposed/strided views straight
    # to the Triton kernel led to miscalculated data-transfer volumes ("ub
    # overflow") under aclgraph. .contiguous() is a no-op for already-contiguous
    # tensors such as typical weights.
    x = x.contiguous()
    y = y.contiguous()

    M, K = x.shape
    N, _ = y.shape

    # Allocate output tensor (same data type as x)
    output = torch.zeros((M, N), dtype=x.dtype, device=x.device)

    # Fixed block sizes: the reduction order per output element does not depend on M
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128

    # Number of output tiles per dimension (ceil division)
    num_blocks_m = triton.cdiv(M, BLOCK_M)
    num_blocks_n = triton.cdiv(N, BLOCK_N)

    # Fixed 1D grid: half of the current device's vector cores. Clamp to 1. A zero
    # grid would be an invalid launch, and GRID_SIZE is also the kernel's
    # tile-loop step.
    num_vectorcore = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"]
    grid_size = max(1, num_vectorcore // 2)
    grid = (grid_size,)

    linear_persistent_kernel[grid](
        a_ptr=x,
        b_ptr=y,
        c_ptr=output,
        M=M,
        N=N,
        K=K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_bn=y.stride(0),
        stride_bk=y.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_BLOCKS_M=num_blocks_m,  # Number of tiles in M dimension
        NUM_BLOCKS_N=num_blocks_n,  # Number of tiles in N dimension
        GRID_SIZE=grid_size,  # Fixed grid size
    )

    return output
|
||
|
||
|
||
def mm_batch_invariant(a, b):
    """Batch-invariant stand-in for ``torch.mm``: returns ``a @ b`` via matmul_persistent."""
    product = matmul_persistent(a, b)
    return product
|
||
|
||
|
||
def bmm_batch_invariant(a, b, *, out=None):
    """Batched matrix multiply (B, M, K) x (B, K, N) -> (B, M, N), batch-invariantly.

    Each batch entry is multiplied independently with ``matmul_persistent`` and the
    results are stacked. Per-entry results therefore do not depend on B.

    Args:
        a: 3D tensor of shape [B, M, K].
        b: 3D tensor of shape [B, K, N].
        out: optional tensor of shape [B, M, N]. When given, the result is copied
            into it and ``out`` is returned.

    Returns:
        The [B, M, N] product (``out`` when provided). For B == 0, an empty
        [0, M, N] tensor with the dtype and device of ``a``.

    Raises:
        ValueError: if either operand is not 3D, or the batch sizes differ.
    """
    if not (a.ndim == 3 and b.ndim == 3):
        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")
    # Previously a larger b silently had its extra batches ignored, and a smaller
    # b raised IndexError mid-loop. Reject mismatches up front.
    if a.shape[0] != b.shape[0]:
        raise ValueError(
            f"bmm_batch_invariant batch size mismatch: a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}"
        )

    if a.shape[0] == 0:
        # torch.stack([]) raises, so build the empty result explicitly.
        result = torch.empty((0, a.shape[1], b.shape[2]), dtype=a.dtype, device=a.device)
    else:
        result = torch.stack([matmul_persistent(a[i], b[i]) for i in range(a.shape[0])], dim=0)

    if out is not None:
        out.copy_(result)
        return out
    return result
|
||
|
||
|
||
def addmm_batch_invariant(bias, a, b):
    """Batch-invariant stand-in for ``torch.addmm(bias, a, b)``: returns ``a @ b + bias``.

    NOTE(review): unlike torch.addmm, there are no beta/alpha scalars, and
    matmul_persistent requires ``bias`` to be 1D of length N.
    """
    return matmul_persistent(x=a, y=b, bias=bias)
|
||
|
||
|
||
def matmul_batch_invariant(a, b, *, out=None):
    """Batch-invariant stand-in for ``torch.matmul`` over a fixed set of rank pairs.

    Supported (a.ndim, b.ndim) combinations:
      * (2, 2): [M, K] @ [K, N] -> [M, N] via matmul_persistent
      * (3, 3): batched, delegated to bmm_batch_invariant
      * (3, 2): [B, S, H] @ [H, N] -> [B, S, N], flattening B*S into rows
      * (2, 3): [M, K] @ [B, K, N] -> [B, M, N], broadcasting ``a`` over B
      * (4, 4): [B, Hd, S, D] @ [B, Hd, D, T] -> [B, Hd, S, T], folding B*Hd into
        one batch dimension

    When ``out`` is given, the result is copied into it and ``out`` is returned.

    Raises:
        ValueError: for any other rank combination.
    """

    def _emit(res):
        # Follow torch's ``out=`` convention: fill and return ``out`` when provided.
        if out is None:
            return res
        out.copy_(res)
        return out

    ranks = (a.ndim, b.ndim)

    if ranks == (2, 2):
        return _emit(matmul_persistent(a, b))

    if ranks == (3, 3):
        return bmm_batch_invariant(a, b, out=out)

    if ranks == (3, 2):
        # Typical linear-layer shape: fold (batch, seq) into rows, then unfold.
        n_batch, n_seq, n_hidden = a.shape
        flat = matmul_persistent(a.reshape(-1, n_hidden), b)
        return _emit(flat.reshape(n_batch, n_seq, -1))

    if ranks == (2, 3):
        # expand() creates a broadcast view; no data is copied here.
        broadcast_a = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(broadcast_a, b, out=out)

    if ranks == (4, 4):
        # Attention-style tensors: fold (batch, heads) into one batch dimension.
        n_batch, n_heads, rows_a, inner_a = a.shape
        _, _, inner_b, cols_b = b.shape
        folded = n_batch * n_heads
        res_3d = bmm_batch_invariant(
            a.reshape(folded, rows_a, inner_a),
            b.reshape(folded, inner_b, cols_b),
        )
        return _emit(res_3d.reshape(n_batch, n_heads, rows_a, cols_b))

    raise ValueError(
        f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
        f"3D x 2D, 2D x 3D, and 4D x 4D, "
        f"got shapes {a.shape} and {b.shape}"
    )
|
||
|
||
|
||
def linear_batch_invariant(input_, weight, bias=None):
    """Batch-invariant stand-in for ``torch.nn.functional.linear``.

    Computes ``input_ @ weight.T (+ bias)``. linear_persistent only accepts 2D
    operands, so for non-2D ``input_`` (e.g. [batch, seq, K] or a single [K]
    vector) the leading dimensions are flattened into rows before the call and
    restored afterwards. Previously such inputs failed linear_persistent's 2D
    assertion. 2D inputs take exactly the original path.

    Args:
        input_: tensor of shape [..., K].
        weight: 2D tensor of shape [N, K].
        bias: optional tensor broadcastable to [..., N], added after the matmul.

    Returns:
        Tensor of shape [..., N].
    """
    if input_.dim() == 2:
        output = linear_persistent(input_, weight)
    else:
        lead_shape = input_.shape[:-1]
        flat = linear_persistent(input_.reshape(-1, input_.shape[-1]), weight)
        output = flat.reshape(*lead_shape, flat.shape[-1])

    if bias is not None:
        output = output + bias
    return output
|