42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from typing import Tuple
|
|
from math import ceil
|
|
|
|
# Exclusive upper bound on the launch-grid size: adjust_kernel_block_size
# requires ceil(m / block_m) * ceil(n / block_n) to stay strictly below it.
_MLU_MAX_GRID_SIZE = 65536


def adjust_kernel_block_size(
    m: int,
    block_m: int,
    n: int,
    block_n: int
) -> Tuple[int, int]:
    """Adjust block size to meet mlu triton grid restrictions.

    Candidate block sizes are (16, 32, 64, 96, 128, 192, 256). The search
    walks both dimensions up the candidate list in lockstep, one candidate
    per step. It starts from ``block_m = 32`` and from ``block_n = 16`` when
    the given ``block_n`` is 16, else ``block_n = 32``. It returns the first
    pair whose grid ``ceil(m / block_m) * ceil(n / block_n)`` is strictly
    below ``_MLU_MAX_GRID_SIZE`` (65536).

    Calculation of the max block size in candidates list:

        LLama3.1-8b-tp1 max n is 14336
        LLama3.1-70b-tp4 max n is 7168
        LLama3.1-405b-tp8 max n is 6656

    Take n = 14336. When the search starts from block_n = 32, the last pair
    tried is (256, 256), so the max sequence length is
    floor(65535 / ceil(14336 / 256)) * 256 = 1170 * 256 = 299520.
    When it starts from block_n = 16, the last pair tried is (256, 192),
    which caps m at floor(65535 / ceil(14336 / 192)) * 256 = 873 * 256 = 223488.
    (65535 is used because the grid must be strictly below 65536.)

    Args:
        m: Size of the first tiled dimension; the error message treats it as
            the max sequence length.
        block_m: Ignored -- the search always starts from 32. Kept for
            interface compatibility.
        n: Size of the second tiled dimension (the model figures above are
            example values of n).
        block_n: Only selects the starting candidate: 16 starts the search
            at 16; any other value starts it at 32.

    Returns:
        ``(block_m, block_n)``: the first candidate pair whose grid fits.

    Raises:
        ValueError: If no candidate pair brings the grid below the limit.
    """
    candidates = (16, 32, 64, 96, 128, 192, 256)
    n_start = 0 if block_n == 16 else 1

    # zip advances both dimensions together and stops at the shorter slice.
    # The block_m side always has 6 entries (it skips 16), so at most 6
    # pairs are tried -- the same bound as the original index loop.
    for cand_m, cand_n in zip(candidates[1:], candidates[n_start:]):
        if ceil(m / cand_m) * ceil(n / cand_n) < _MLU_MAX_GRID_SIZE:
            return cand_m, cand_n

    raise ValueError(f"the max seq len {m} is too long for lora triton kernel")
|