# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Tuple
from math import ceil

_MLU_MAX_GRID_SIZE = 65536

# Block-size candidates, tried in ascending order by adjust_kernel_block_size.
_BLOCK_SIZE_CANDIDATES = (16, 32, 64, 96, 128, 192, 256)


def _ceil_div(a: int, b: int) -> int:
    """Return ceil(a / b) using exact integer (floor-division) arithmetic."""
    return -(-a // b)


def _grid_size(m: int, block_m: int, n: int, block_n: int) -> int:
    """Return the number of block_m x block_n tiles needed to cover m x n."""
    return _ceil_div(m, block_m) * _ceil_div(n, block_n)


def adjust_kernel_block_size(
    m: int, block_m: int, n: int, block_n: int
) -> Tuple[int, int]:
    """Adjust block size to meet mlu triton grid restrictions.

    Searches the candidate block sizes (16, 32, 64, 96, 128, 192, 256) for
    the first (block_m, block_n) pair whose tile count
    ``ceil(m / block_m) * ceil(n / block_n)`` is strictly below
    ``_MLU_MAX_GRID_SIZE``.

    The search always starts with block_m = 32. The incoming ``block_m``
    value is not used. ``block_n`` only selects the starting point:

    * if ``block_n == 16``, block_n starts at 16 and trails block_m by one
      candidate, giving (32, 16), (64, 32), ..., (256, 192);
    * otherwise both start at 32 and advance together, giving
      (32, 32), (64, 64), ..., (256, 256).

    The search stops once block_m reaches 256.

    Calculation of the max block size in candidates list:
        LLama3.1-8b-tp1 max n is 14336
        LLama3.1-70b-tp4 max n is 7168
        LLama3.1-405b-tp8 max n is 6656
    When n is 14336 and the block size is 256, ceil(14336 / 256) = 56. The
    bound ``ceil(m / 256) * 56 < 65536`` then holds up to
    m = floor(65536 / 56) * 256 = 1170 * 256 = 299520.

    Args:
        m: Size of the first dimension to tile (compared against the max
            sequence length in the error message).
        block_m: Ignored; kept for interface compatibility.
        n: Size of the second dimension to tile.
        block_n: Only checked for equality with 16 to choose the starting
            block_n.

    Returns:
        The selected ``(block_m, block_n)`` pair.

    Raises:
        ValueError: If no candidate pair yields a grid strictly smaller
            than ``_MLU_MAX_GRID_SIZE``.
    """
    # Index of the first block_n candidate: 16 when the caller asked for 16,
    # otherwise 32 (matching block_m's start).
    n_start = 0 if block_n == 16 else 1
    # zip stops at the shorter slice, i.e. when block_m runs out of
    # candidates. This matches the original index-based loop bound.
    for cand_m, cand_n in zip(
        _BLOCK_SIZE_CANDIDATES[1:], _BLOCK_SIZE_CANDIDATES[n_start:]
    ):
        if _grid_size(m, cand_m, n, cand_n) < _MLU_MAX_GRID_SIZE:
            return cand_m, cand_n
    raise ValueError(f"the max seq len {m} is too long for lora triton kernel")