perf: adaptive block size selection in linear_persistent kernel (#6537)
### What this PR does / why we need it?

**Optimization:** Replaces fixed block sizes (128x128x128) in `linear_persistent_kernel` with adaptive selection logic that considers:

- Matrix dimensions (M, N, K)
- Device NPU vector core count
- Data type (float32 vs others)

**Why:** Fixed block sizes lead to suboptimal hardware utilization across different matrix shapes. Adaptive sizing maximizes occupancy and memory efficiency for varied workload patterns, improving throughput for batch-invariant linear operations in LLM inference.

**Details:**

- Small matrices (M < 256): Size-proportional allocation
- Medium matrices (256 ≤ M < 1024): Balanced distribution based on grid capacity
- Large matrices (M ≥ 1024): Optimized for dominant dimension

### Does this PR introduce _any_ user-facing change?

No. This is a performance optimization. The API and numerical results remain unchanged; only kernel execution efficiency improves.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

Signed-off-by: DDCHY <843049740@qq.com>
Signed-off-by: zjchenn <zjchenn@gmail.com>
Co-authored-by: DDCHY <843049740@qq.com>
This commit is contained in:
@@ -269,15 +269,59 @@ def linear_persistent(x, y):
|
||||
# Allocate output tensor (same data type as x)
|
||||
output = torch.zeros((M, N), dtype=x.dtype, device=x.device)
|
||||
|
||||
grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
|
||||
|
||||
# Define block sizes (can be adjusted based on hardware)
|
||||
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
|
||||
BLOCK_K = 256
|
||||
if x.dtype == torch.float32:
|
||||
BLOCK_K = BLOCK_K // 2
|
||||
grid_size_div4 = grid_size // 4
|
||||
if M == 0 or N == 0:
|
||||
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 256
|
||||
elif M < 256:
|
||||
BLOCK_M = M
|
||||
if grid_size * 128 <= N:
|
||||
if M <= 128:
|
||||
BLOCK_N = 256
|
||||
else:
|
||||
BLOCK_N = 128
|
||||
elif grid_size * 32 >= N:
|
||||
if M > N:
|
||||
BLOCK_M = triton.cdiv(M, grid_size_div4)
|
||||
BLOCK_N = triton.cdiv(N, 4)
|
||||
else:
|
||||
BLOCK_M = triton.cdiv(M, 4)
|
||||
BLOCK_N = triton.cdiv(N, grid_size_div4)
|
||||
else:
|
||||
BLOCK_N = triton.next_power_of_2(triton.cdiv(N, grid_size))
|
||||
elif M >= 256 and M < 1024:
|
||||
if M < N:
|
||||
BLOCK_M = 256
|
||||
nums_m = triton.cdiv(M, BLOCK_M)
|
||||
nums_n = grid_size // nums_m
|
||||
if 128 * nums_n <= N:
|
||||
BLOCK_N = 128
|
||||
else:
|
||||
BLOCK_N = min(triton.next_power_of_2(triton.cdiv(N, nums_n)), 128)
|
||||
else:
|
||||
BLOCK_M = min(triton.cdiv(M, grid_size_div4), 256)
|
||||
BLOCK_N = min(triton.cdiv(N, 4), 128)
|
||||
else:
|
||||
if M > N:
|
||||
BLOCK_M, BLOCK_N = 256, 128
|
||||
nums_m = triton.cdiv(M, BLOCK_M)
|
||||
nums_n = triton.cdiv(N, BLOCK_N)
|
||||
if nums_m * nums_n < grid_size:
|
||||
BLOCK_M = triton.cdiv(M, grid_size_div4)
|
||||
BLOCK_N = triton.cdiv(N, 4)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = 128, 256
|
||||
|
||||
# Calculate number of blocks per dimension (ceil division)
|
||||
num_blocks_m = triton.cdiv(M, BLOCK_M)
|
||||
num_blocks_n = triton.cdiv(N, BLOCK_N)
|
||||
|
||||
# Set fixed 1D grid size
|
||||
grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2
|
||||
grid = (grid_size,)
|
||||
|
||||
# Launch kernel
|
||||
|
||||
Reference in New Issue
Block a user