Files
2026-04-24 09:58:03 +08:00

42 lines
1.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
from math import ceil
# Exclusive upper bound on ceil(m / block_m) * ceil(n / block_n), i.e. on the
# number of tiles the (m, n) problem is split into.
_MLU_MAX_GRID_SIZE = 65536


def adjust_kernel_block_size(
    m: int,
    block_m: int,
    n: int,
    block_n: int
) -> Tuple[int, int]:
    """Adjust block size to meet mlu triton grid restrictions.

    Candidate block sizes are tried from small to large until the tile count
    ``ceil(m / block_m) * ceil(n / block_n)`` drops below
    ``_MLU_MAX_GRID_SIZE``. The m and n candidates advance together and each
    one saturates at the largest candidate, so the ``(256, 256)`` pair is
    always tried before giving up.

    Calculation of the max block size in candidates list:
        Llama3.1-8b-tp1 max n is 14336
        Llama3.1-70b-tp4 max n is 7168
        Llama3.1-405b-tp8 max n is 6656
    When n is 14336, the max sequence length of block size 256 can be
    floor(65536 / ceil(14336 / 256)) * 256 = 299520.

    Args:
        m: Size of the first problem dimension (reported as the sequence
            length in the error message).
        block_m: Accepted for interface compatibility; its incoming value is
            not read, the search always starts at 32 for this dimension.
        n: Size of the second problem dimension.
        block_n: Only compared against 16: if it equals 16 the search for
            this dimension starts at 16, otherwise it starts at 32.

    Returns:
        The first ``(block_m, block_n)`` candidate pair whose tile count is
        below ``_MLU_MAX_GRID_SIZE``.

    Raises:
        ValueError: If even the largest candidate pair produces too many
            tiles.
    """
    candidates = (16, 32, 64, 96, 128, 192, 256)
    last_idx = len(candidates) - 1

    m_idx = 1
    n_idx = 0 if block_n == 16 else 1
    while True:
        block_m = candidates[m_idx]
        block_n = candidates[n_idx]
        if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
            return block_m, block_n
        if m_idx == last_idx and n_idx == last_idx:
            # Both dimensions already use the largest block; nothing left.
            break
        # Advance in lock-step but clamp at the last candidate, so a
        # dimension that started one step behind still reaches the maximum.
        m_idx = min(m_idx + 1, last_idx)
        n_idx = min(n_idx + 1, last_idx)

    raise ValueError(f"the max seq len {m} is too long for lora triton kernel")