From 0ead5e8681c4f13c7746d2611268e263f1514ec8 Mon Sep 17 00:00:00 2001 From: Zhijun Chen Date: Wed, 4 Feb 2026 21:36:26 +0800 Subject: [PATCH] perf: adaptive block size selection in linear_persistent kernel (#6537) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? **Optimization:** Replaces fixed block sizes (128x128x128) in `linear_persistent_kernel` with adaptive selection logic that considers: - Matrix dimensions (M, N, K) - Device NPU vector core count - Data type (float32 vs others) **Why:** Fixed block sizes lead to suboptimal hardware utilization across different matrix shapes. Adaptive sizing maximizes occupancy and memory efficiency for varied workload patterns, improving throughput for batch-invariant linear operations in LLM inference. **Details:** - Small matrices (M < 256): Size-proportional allocation - Medium matrices (256 ≤ M < 1024): Balanced distribution based on grid capacity - Large matrices (M ≥ 1024): Optimized for dominant dimension ### Does this PR introduce _any_ user-facing change? No. This is a performance optimization. The API and numerical results remain unchanged; only kernel execution efficiency improves. ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 Signed-off-by: DDCHY <843049740@qq.com> Signed-off-by: zjchenn Co-authored-by: DDCHY <843049740@qq.com> --- .../ops/triton/batch_invariant/matmul.py | 48 ++++++++++++++++++- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/triton/batch_invariant/matmul.py b/vllm_ascend/ops/triton/batch_invariant/matmul.py index 175d002e..e5606a27 100644 --- a/vllm_ascend/ops/triton/batch_invariant/matmul.py +++ b/vllm_ascend/ops/triton/batch_invariant/matmul.py @@ -269,15 +269,59 @@ def linear_persistent(x, y): # Allocate output tensor (same data type as x) output = torch.zeros((M, N), dtype=x.dtype, device=x.device) + grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2 + # Define block sizes (can be adjusted based on hardware) - BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + BLOCK_K = 256 + if x.dtype == torch.float32: + BLOCK_K = BLOCK_K // 2 + grid_size_div4 = grid_size // 4 + if M == 0 or N == 0: + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 256 + elif M < 256: + BLOCK_M = M + if grid_size * 128 <= N: + if M <= 128: + BLOCK_N = 256 + else: + BLOCK_N = 128 + elif grid_size * 32 >= N: + if M > N: + BLOCK_M = triton.cdiv(M, grid_size_div4) + BLOCK_N = triton.cdiv(N, 4) + else: + BLOCK_M = triton.cdiv(M, 4) + BLOCK_N = triton.cdiv(N, grid_size_div4) + else: + BLOCK_N = triton.next_power_of_2(triton.cdiv(N, grid_size)) + elif M >= 256 and M < 1024: + if M < N: + BLOCK_M = 256 + nums_m = triton.cdiv(M, BLOCK_M) + nums_n = grid_size // nums_m + if 128 * nums_n <= N: + BLOCK_N = 128 + else: + BLOCK_N = min(triton.next_power_of_2(triton.cdiv(N, nums_n)), 128) + else: + BLOCK_M = min(triton.cdiv(M, grid_size_div4), 256) + BLOCK_N = min(triton.cdiv(N, 4), 128) + else: + if M > N: + BLOCK_M, BLOCK_N = 256, 128 + nums_m = triton.cdiv(M, BLOCK_M) + nums_n = triton.cdiv(N, BLOCK_N) + if nums_m * nums_n < grid_size: + BLOCK_M = triton.cdiv(M, grid_size_div4) + BLOCK_N = triton.cdiv(N, 4) + else: + BLOCK_M, BLOCK_N = 128, 256 # Calculate number of blocks per dimension (ceil division) num_blocks_m = triton.cdiv(M, BLOCK_M) num_blocks_n = triton.cdiv(N, BLOCK_N) # Set fixed 1D grid size - grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2 grid = (grid_size,) # Launch kernel