diff --git a/vllm_ascend/ops/triton/batch_invariant/matmul.py b/vllm_ascend/ops/triton/batch_invariant/matmul.py index 175d002e..e5606a27 100644 --- a/vllm_ascend/ops/triton/batch_invariant/matmul.py +++ b/vllm_ascend/ops/triton/batch_invariant/matmul.py @@ -269,15 +269,59 @@ def linear_persistent(x, y): # Allocate output tensor (same data type as x) output = torch.zeros((M, N), dtype=x.dtype, device=x.device) + grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2 + # Define block sizes (can be adjusted based on hardware) - BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + BLOCK_K = 256 + if x.dtype == torch.float32: + BLOCK_K = BLOCK_K // 2 + grid_size_div4 = grid_size // 4 + if M == 0 or N == 0: + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 256 + elif M < 256: + BLOCK_M = M + if grid_size * 128 <= N: + if M <= 128: + BLOCK_N = 256 + else: + BLOCK_N = 128 + elif grid_size * 32 >= N: + if M > N: + BLOCK_M = triton.cdiv(M, grid_size_div4) + BLOCK_N = triton.cdiv(N, 4) + else: + BLOCK_M = triton.cdiv(M, 4) + BLOCK_N = triton.cdiv(N, grid_size_div4) + else: + BLOCK_N = triton.next_power_of_2(triton.cdiv(N, grid_size)) + elif M >= 256 and M < 1024: + if M < N: + BLOCK_M = 256 + nums_m = triton.cdiv(M, BLOCK_M) + nums_n = grid_size // nums_m + if 128 * nums_n <= N: + BLOCK_N = 128 + else: + BLOCK_N = min(triton.next_power_of_2(triton.cdiv(N, nums_n)), 128) + else: + BLOCK_M = min(triton.cdiv(M, grid_size_div4), 256) + BLOCK_N = min(triton.cdiv(N, 4), 128) + else: + if M > N: + BLOCK_M, BLOCK_N = 256, 128 + nums_m = triton.cdiv(M, BLOCK_M) + nums_n = triton.cdiv(N, BLOCK_N) + if nums_m * nums_n < grid_size: + BLOCK_M = triton.cdiv(M, grid_size_div4) + BLOCK_N = triton.cdiv(N, 4) + else: + BLOCK_M, BLOCK_N = 128, 256 # Calculate number of blocks per dimension (ceil division) num_blocks_m = triton.cdiv(M, BLOCK_M) num_blocks_n = triton.cdiv(N, BLOCK_N) # Set fixed 1D grid size - grid_size = driver.active.utils.get_device_properties(torch.npu.current_device())["num_vectorcore"] // 2 grid = (grid_size,) # Launch kernel