xc-llm-ascend/vllm_ascend/ops/triton/batch_invariant/matmul.py
huangning1995 ce11fd49f3 [Feature] Batch invariant torch.compile (#6107)
### What this PR does / why we need it?
Building upon https://github.com/vllm-project/vllm-ascend/pull/5517, which
enabled batch-invariant (BI) inference in vllm-ascend, we observed that BI
performance in eager mode remains suboptimal.

This PR further integrates batch-invariant with torch.compile, which
improves inference performance by 350% when tested with Qwen3-0.6B.

### Does this PR introduce _any_ user-facing change?
Previously, enabling both aclgraph and Batch-Invariant would cause an
"ub overflow" error. This occurred because transposed input tensors
could produce incorrect stride() values.

To fix this, we now call .contiguous() on the input tensors before
passing them to Triton kernels. This ensures a contiguous memory layout
and prevents transposed tensors from causing incorrect stride
calculations.
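
The stride issue can be illustrated with a minimal sketch (plain PyTorch; the shapes are illustrative, not taken from the kernels):

```python
import torch

# A row-major (contiguous) 4x3 tensor: advancing one row skips 3 elements.
x = torch.randn(4, 3)
print(x.stride())  # (3, 1)

# Transposing swaps the strides without moving any data, so the 3x4 view
# is no longer contiguous -- the kind of layout that led to the
# miscalculated data-transfer sizes described above.
xt = x.t()
print(xt.stride())         # (1, 3)
print(xt.is_contiguous())  # False

# .contiguous() materializes a fresh row-major copy with standard strides.
xc = xt.contiguous()
print(xc.stride())         # (4, 1)
```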

### Test Plan
pytest -sv --durations=0
tests/e2e/singlecard/test_aclgraph_batch_invariant.py

### Test Result
```
============================================================================ slowest durations ============================================================================
87.37s call     tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle
77.39s call     tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
74.04s call     tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_logprobs_without_batch_invariance_should_fail
73.59s call     tests/e2e/singlecard/test_aclgraph_batch_invariant.py::test_simple_generation

(8 durations < 0.005s hidden.  Use -vv to show these durations.)
================================================================ 4 passed, 3 warnings in 312.45s (0:05:12) ================================================================
```
### Performance
```
export VLLM_BATCH_INVARIANT=1
vllm serve /home/Qwen3-0.6B \
--served-model-name qwen \
--port 8000 \
--max-num-seqs 256 \
--tensor-parallel-size 1 \
--max-model-len 5500 \
--max-num-batched-tokens 5500 \
--reasoning-parser qwen3 \
--gpu-memory-utilization 0.9 \
--compilation_config '{"cudagraph_mode":"FULL_DECODE_ONLY","cudagraph_capture_sizes":[1,2,4,8,16,32]}' \
--additional-config '{"ascend_scheduler_config":{"enabled":true},"enable_weight_nz_layout":true}'

vllm bench serve --served-model-name qwen --trust-remote-code --backend vllm \
--model /home/Qwen3-0.6B/ --endpoint /v1/completions --dataset-name random \
--random-input-len 512 --random-output-len 256 --num-prompts 800 \
--max-concurrency 8
```

torch.compile batch invariant performance:
```
============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  477.21    
Total input tokens:                      409600    
Total generated tokens:                  204800    
Request throughput (req/s):              1.68      
Output token throughput (tok/s):         429.16    
Peak output token throughput (tok/s):    472.00    
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          1287.48   
---------------Time to First Token----------------
Mean TTFT (ms):                          285.53    
Median TTFT (ms):                        312.70    
P99 TTFT (ms):                           324.22    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.59     
Median TPOT (ms):                        17.50     
P99 TPOT (ms):                           18.44     
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.59     
Median ITL (ms):                         17.45     
P99 ITL (ms):                            18.76     
==================================================
```
Eager batch invariant performance:
```
============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  1694.70   
Total input tokens:                      409600    
Total generated tokens:                  204800    
Request throughput (req/s):              0.47      
Output token throughput (tok/s):         120.85    
Peak output token throughput (tok/s):    136.00    
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          362.54    
---------------Time to First Token----------------
Mean TTFT (ms):                          164.29    
Median TTFT (ms):                        129.71    
P99 TTFT (ms):                           1961.66   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          65.81     
Median TPOT (ms):                        65.15     
P99 TPOT (ms):                           72.27     
---------------Inter-token Latency----------------
Mean ITL (ms):                           65.81     
Median ITL (ms):                         64.64     
P99 ITL (ms):                            75.72     
==================================================
```
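
As a quick sanity check (not part of the PR), the speedup can be computed directly from the throughput and latency figures reported in the two tables above:

```python
# Output token throughput from the two benchmark runs above (tok/s).
compiled = 429.16  # torch.compile batch invariant
eager = 120.85     # eager batch invariant

speedup = compiled / eager
print(f"{speedup:.2f}x")  # 3.55x higher output throughput

# Per-token decode latency (mean TPOT, ms) tells the same story.
print(f"{65.81 / 17.59:.2f}x")  # 3.74x lower decode latency
```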

- vLLM version: v0.13.0
- vLLM main:
d68209402d

---------

Signed-off-by: huangning1995 <huangning12@huawei.com>
2026-01-26 09:15:06 +08:00


# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from triton.runtime import driver # type: ignore
from vllm.triton_utils import tl, triton


@triton.jit
def matmul_bias_persistent_kernel(
    # Input tensor pointers
    x_ptr,
    y_ptr,
    bias_ptr,
    output_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Stride information
    stride_xm,
    stride_xk,  # Strides of x: [M, K]
    stride_yk,
    stride_yn,  # Strides of y: [K, N]
    stride_bias,  # Stride of bias: [N]
    stride_outm,
    stride_outn,  # Strides of output: [M, N]
    # Whether to use bias
    has_bias: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)  # Row block ID
    pid_n = tl.program_id(1)  # Column block ID
    # Calculate the starting position of the current block in the matrix
    rm_start = pid_m * BLOCK_M
    rn_start = pid_n * BLOCK_N
    # Create index ranges
    rm = rm_start + tl.arange(0, BLOCK_M)  # Row index range [BLOCK_M]
    rn = rn_start + tl.arange(0, BLOCK_N)  # Column index range [BLOCK_N]
    rk = tl.arange(0, BLOCK_K)  # K dimension index range [BLOCK_K]
    # Initialize accumulator to 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Loop over the K dimension, processing BLOCK_K elements per iteration
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_start = k * BLOCK_K
        # Calculate pointer offsets for x (row-major)
        x_ptrs = x_ptr + rm[:, None] * stride_xm + (rk[None, :] + k_start) * stride_xk
        # Calculate pointer offsets for y (row-major)
        y_ptrs = y_ptr + (rk[:, None] + k_start) * stride_yk + rn[None, :] * stride_yn
        # Create masks to prevent out-of-bounds access
        x_mask = (rm[:, None] < M) & ((rk[None, :] + k_start) < K)
        y_mask = ((rk[:, None] + k_start) < K) & (rn[None, :] < N)
        # Load data chunks from global memory
        x_chunk = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        y_chunk = tl.load(y_ptrs, mask=y_mask, other=0.0).to(tl.float32)
        # Compute matrix multiplication accumulation
        acc += tl.dot(x_chunk, y_chunk, allow_tf32=False)
    # Add bias if the has_bias flag is set
    if has_bias:
        # Load bias values (broadcast to all rows)
        bias_ptrs = bias_ptr + rn * stride_bias
        bias_mask = rn < N
        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float32)
        # Add bias to accumulator (automatic broadcasting)
        acc += bias_vals[None, :]
    # Calculate output pointer positions
    out_ptrs = output_ptr + rm[:, None] * stride_outm + rn[None, :] * stride_outn
    out_mask = (rm[:, None] < M) & (rn[None, :] < N)
    # Store result to global memory
    tl.store(out_ptrs, acc.to(out_ptrs.dtype.element_ty), mask=out_mask)


def matmul_persistent(x, y, bias=None):
    """
    Implement matrix multiplication with optional bias using Triton: x @ y + bias (if bias is not None)
    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [K, N]
        bias: torch.Tensor, shape [N] or None
    Returns:
        output: torch.Tensor, shape [M, N]
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[0], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[0]={y.shape[0]}"
    # Convert tensors to contiguous memory layout.
    # This prevents transposed tensors from causing incorrect stride() values,
    # which would lead to miscalculated data transfer volumes in subsequent operations.
    x = x.contiguous()
    y = y.contiguous()
    M, K = x.shape
    _, N = y.shape
    # Validate bias shape (if not None)
    if bias is not None:
        assert bias.dim() == 1, "bias must be a 1D tensor"
        assert y.shape[1] == bias.shape[0], (
            f"Bias dimension mismatch: y.shape[1]={y.shape[1]}, bias.shape[0]={bias.shape[0]}"
        )
    # Allocate output tensor (same data type as x)
    output = torch.empty((M, N), dtype=x.dtype, device=x.device)
    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    # Calculate grid size (one program instance per output block)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    # Handle case when bias is None
    if bias is None:
        # Create a dummy bias tensor (will not be used as has_bias=False)
        dummy_bias = torch.empty(0, dtype=x.dtype, device=x.device)
        has_bias = False
        bias_stride = 0
        bias_to_pass = dummy_bias
    else:
        has_bias = True
        bias_stride = bias.stride(0)
        bias_to_pass = bias
    # Launch kernel
    matmul_bias_persistent_kernel[grid](
        x,
        y,
        bias_to_pass,
        output,  # Input/Output tensors
        M,
        N,
        K,  # Matrix dimensions
        x.stride(0),
        x.stride(1),  # Strides of x
        y.stride(0),
        y.stride(1),  # Strides of y
        bias_stride,  # Stride of bias (0 if bias is None)
        output.stride(0),
        output.stride(1),  # Strides of output
        has_bias,  # Flag indicating whether to use bias
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    return output


@triton.jit
def linear_persistent_kernel(
    a_ptr,  # Pointer to tensor a, shape [M, K]
    b_ptr,  # Pointer to tensor b, shape [N, K]
    c_ptr,  # Pointer to output tensor c, shape [M, N]
    M,  # Number of rows in tensor a
    N,  # Number of rows in tensor b (number of columns in output c)
    K,  # Number of columns in both tensor a and tensor b
    stride_am,  # Stride of tensor a along dimension M (typically K)
    stride_ak,  # Stride of tensor a along dimension K (typically 1)
    stride_bn,  # Stride of tensor b along dimension N (typically K)
    stride_bk,  # Stride of tensor b along dimension K (typically 1)
    stride_cm,  # Stride of tensor c along dimension M (typically N)
    stride_cn,  # Stride of tensor c along dimension N (typically 1)
    BLOCK_M: tl.constexpr,  # Block size for M dimension
    BLOCK_N: tl.constexpr,  # Block size for N dimension
    BLOCK_K: tl.constexpr,  # Block size for K dimension
    NUM_BLOCKS_M: tl.constexpr,  # Number of blocks in M dimension
    NUM_BLOCKS_N: tl.constexpr,  # Number of blocks in N dimension
    GRID_SIZE: tl.constexpr,  # Fixed 1D grid size
):
    # Get current program's 1D index (1D grid)
    pid = tl.program_id(0)
    total_blocks = NUM_BLOCKS_M * NUM_BLOCKS_N  # Total number of output blocks
    # Loop over the multiple blocks assigned to the current program
    for block_index in range(pid, total_blocks, GRID_SIZE):
        # Convert 1D block index to 2D coordinates (m_block, n_block)
        m_block = block_index // NUM_BLOCKS_N
        n_block = block_index % NUM_BLOCKS_N
        # Calculate starting indices of the current output block
        start_m = m_block * BLOCK_M
        start_n = n_block * BLOCK_N
        # Create row and column index ranges within the current block
        m_indices = start_m + tl.arange(0, BLOCK_M)
        n_indices = start_n + tl.arange(0, BLOCK_N)
        # Create masks to handle boundaries
        m_mask = m_indices < M
        n_mask = n_indices < N
        # Initialize accumulator to 0
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Loop over K dimension with step size BLOCK_K
        for k_offset in range(0, K, BLOCK_K):
            k_indices = k_offset + tl.arange(0, BLOCK_K)
            k_mask = k_indices < K
            # Load block of tensor a: shape [BLOCK_M, BLOCK_K]
            a_ptrs = a_ptr + m_indices[:, None] * stride_am + k_indices[None, :] * stride_ak
            a_vals = tl.load(a_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
            # Load block of tensor b: shape [BLOCK_N, BLOCK_K]
            b_ptrs = b_ptr + n_indices[:, None] * stride_bn + k_indices[None, :] * stride_bk
            b_vals = tl.load(b_ptrs, mask=n_mask[:, None] & k_mask[None, :], other=0.0)
            # Explicitly transpose b matrix using tl.trans: shape becomes [BLOCK_K, BLOCK_N]
            b_vals_transposed = tl.trans(b_vals)
            # Compute matrix multiplication: a_vals x b_vals_transposed
            product = tl.dot(a_vals, b_vals_transposed)
            acc += product
        # Store result to output tensor c
        c_ptrs = c_ptr + m_indices[:, None] * stride_cm + n_indices[None, :] * stride_cn
        tl.store(c_ptrs, acc, mask=m_mask[:, None] & n_mask[None, :])


def linear_persistent(x, y):
    """
    Implement matrix multiplication with Triton: x @ y^T
    Uses a fixed-size 1D grid
    Parameters:
        x: torch.Tensor, shape [M, K]
        y: torch.Tensor, shape [N, K]
    Returns:
        output: torch.Tensor, shape [M, N]
    """
    # Validate input shapes
    assert x.dim() == 2, "x must be a 2D tensor"
    assert y.dim() == 2, "y must be a 2D tensor"
    assert x.shape[1] == y.shape[1], f"Matrix dimension mismatch: x.shape[1]={x.shape[1]}, y.shape[1]={y.shape[1]}"
    M, K = x.shape
    N, _ = y.shape
    # Allocate output tensor (same data type as x)
    output = torch.zeros((M, N), dtype=x.dtype, device=x.device)
    # Define block sizes (can be adjusted based on hardware)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
    # Calculate number of blocks per dimension (ceil division)
    num_blocks_m = triton.cdiv(M, BLOCK_M)
    num_blocks_n = triton.cdiv(N, BLOCK_N)
    # Set fixed 1D grid size
    grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
    grid = (grid_size,)
    # Launch kernel
    linear_persistent_kernel[grid](
        a_ptr=x,
        b_ptr=y,
        c_ptr=output,
        M=M,
        N=N,
        K=K,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_bn=y.stride(0),
        stride_bk=y.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_BLOCKS_M=num_blocks_m,  # Number of blocks in M dimension
        NUM_BLOCKS_N=num_blocks_n,  # Number of blocks in N dimension
        GRID_SIZE=grid_size,  # Fixed grid size
    )
    return output


def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
    # Process each batch separately with our persistent kernel
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)
        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(f"bmm_batch_invariant expects 3D tensors, got shapes {a.shape} and {b.shape}")


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul can handle various dimensions
    # For 2D x 2D, it's the same as matmul
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 3 and b.ndim == 3:
        # Handle batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 2:
        # Handle 3D x 2D: common for linear layers
        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
        # Reshape to 2D, do mm, reshape back
        batch, seq, hidden = a.shape
        a_2d = a.reshape(-1, hidden)
        result_2d = matmul_persistent(a_2d, b)
        result = result_2d.reshape(batch, seq, -1)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 2 and b.ndim == 3:
        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
        # By broadcasting `a` to 3D, we can reuse the batched matrix
        # multiplication logic.
        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
        return bmm_batch_invariant(a_expanded, b, out=out)
    elif a.ndim == 4 and b.ndim == 4:
        # Handle 4D attention tensors: [batch, heads, seq, dim]
        # Reshape to 3D, process, reshape back
        batch, heads, seq_a, dim_a = a.shape
        _, _, dim_b, seq_b = b.shape
        # Reshape to [batch*heads, seq_a, dim_a]
        a_3d = a.reshape(batch * heads, seq_a, dim_a)
        b_3d = b.reshape(batch * heads, dim_b, seq_b)
        # Do batched matmul
        result_3d = bmm_batch_invariant(a_3d, b_3d)
        # Reshape back to [batch, heads, seq_a, seq_b]
        result = result_3d.reshape(batch, heads, seq_a, seq_b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
            f"3D x 2D, 2D x 3D, and 4D x 4D, "
            f"got shapes {a.shape} and {b.shape}"
        )


def linear_batch_invariant(input_, weight, bias=None):
    output = linear_persistent(input_, weight)
    if bias is not None:
        output = output + bias
    return output