import bisect
import functools
import json
import os
from enum import Enum
from typing import Any, Dict, Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class KERNEL_KINDS(Enum):
    """Decode-attention kernel variants; TOTAL_KIND doubles as an "unset" sentinel."""
    v1_2stages = 0
    v1_2stages_tc = 1
    v2 = 2
    v2_tc = 3
    TOTAL_KIND = 4
class BestConfig:
    """Tuned kernel launch parameters for one (batch_size, seq_len) point."""

    def __init__(self):
        self.batch_size = 0
        self.seq_len = 0
        self.kernel_kind = KERNEL_KINDS.TOTAL_KIND  # sentinel: not yet resolved
        # Stage-1 launch parameters.
        self.BLOCK_N = 0
        self.BLOCK_DIM = 0
        self.num_stages = 0
        self.num_warps = 0
        self.NUM_KV_SPLITS = 0
        # Stage-2 launch parameters.
        self.BLOCK_N_2 = 0
        self.num_stages_2 = 0
        self.num_warps_2 = 0
        # Best measured latency (in microseconds) found during tuning.
        self.best_us = 0
        # Kernel entry points, populated later.
        self.decode_fwd_stage1 = None
        self.decode_fwd_stage2 = None
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int,
                             cache_dtype: Optional[str]) -> str:
    if cache_dtype == "default":
        return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    if "K100_AI" in device_name:
        return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
    elif "BW" in device_name:
        return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
    else:
        raise ValueError(f"Unsupported device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int,
                                   cache_dtype: Optional[str]) -> Dict[Any, Any]:
    # First, look for an optimized configuration in the configs directory.
    json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        # A tuned configuration was found; return it.
        with open(config_file_path) as f:
            return json.load(f)
    logger.warning(
        "Could not find a tuned decode attention configuration at %s; falling "
        "back to the default JSON, which may not give the best performance. "
        "Please tune one.", config_file_path)

    # Fall back to the default configuration shipped for QH=16, KVH=1,
    # QKD=576, VD=512.
    json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.warning(
                "Using default decode attention configuration from %s; it may "
                "not give the best performance.", config_file_path)
            return json.load(f)
    raise ValueError(
        "No default configuration matching QH=16, KVH=1, QKD=576, VD=512 was "
        "found; please provide one.")
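# The JSON files loaded above are expected to follow this shape (reconstructed
# from the parsing in get_config_map below; the concrete values are
# illustrative, not shipped defaults):
#
# {
#   "1": {                       # batch size
#     "1024": {                  # mean kv sequence length
#       "best_us": 12.3,
#       "kernel_kind": "v2_tc",  # one of the KERNEL_KINDS member names
#       "best_config": {
#         "stage1": {"BLOCK_N": 64, "BLOCK_DIM": 64,
#                    "num_stages": 2, "num_warps": 4},
#         "stage2": {"num_stages": 2, "num_warps": 4}
#       }
#     }
#   }
# }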
def get_config_map(attention_configs):
    """Convert the raw JSON dict (string keys) into {batch_size: {seq_len: BestConfig}}."""
    ret_map = {}
    for bs, seq_configs in attention_configs.items():
        seq_map = {}
        ret_map[int(bs)] = seq_map
        for seq_len, kind_config in seq_configs.items():
            configs = BestConfig()
            configs.best_us = kind_config['best_us']
            seq_map[int(seq_len)] = configs

            kind = kind_config['kernel_kind']
            if kind not in ('v1_2stages', 'v1_2stages_tc', 'v2', 'v2_tc'):
                raise ValueError(f"Unknown kernel kind: {kind}")
            configs.kernel_kind = KERNEL_KINDS[kind]

            best_config = kind_config['best_config']
            stage1 = best_config['stage1']
            stage2 = best_config['stage2']
            # Stage-1 launch parameters, common to all kernel kinds.
            configs.BLOCK_N = stage1['BLOCK_N']
            configs.num_stages = stage1['num_stages']
            configs.num_warps = stage1['num_warps']
            # Stage-2 launch parameters.
            configs.num_stages_2 = stage2['num_stages']
            configs.num_warps_2 = stage2['num_warps']
            if kind in ('v1_2stages', 'v1_2stages_tc'):
                configs.BLOCK_N_2 = stage2['BLOCK_N']
            elif kind == 'v2_tc':
                configs.BLOCK_DIM = stage1['BLOCK_DIM']
    return ret_map
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int,
                              cache_dtype: Optional[str]) -> Dict[int, Dict[int, BestConfig]]:
    attention_configs = get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype)
    return get_config_map(attention_configs)
def get_closest_key(dic_keys, target_key):
    """Return the key in dic_keys numerically closest to target_key."""
    # bisect requires sorted input; dict insertion order is not guaranteed to
    # be sorted, so sort defensively.
    keys = sorted(dic_keys)
    idx = bisect.bisect_left(keys, target_key)
    if idx == 0:
        return keys[0]
    if idx == len(keys):
        return keys[-1]
    left_key = keys[idx - 1]
    right_key = keys[idx]
    # Ties resolve to the smaller key.
    if target_key - left_key <= right_key - target_key:
        return left_key
    else:
        return right_key
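# Example: with keys [1, 4, 8, 16], get_closest_key(keys, 6) returns 4 (the
# tie between 4 and 8 resolves to the smaller key), and get_closest_key(keys, 20)
# clamps to 16.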
def get_nearest_config(bs_key, mean_kv_seqlen_key, config):
    """Look up the tuned config closest to (bs_key, mean_kv_seqlen_key)."""
    closest_bs_key = get_closest_key(config.keys(), bs_key)
    closest_mean_kv_seqlen_key = get_closest_key(
        config[closest_bs_key].keys(), mean_kv_seqlen_key)
    return config[closest_bs_key][closest_mean_kv_seqlen_key]
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Exact lookup; raises if the (bs, mean kv seq len) pair was not tuned."""
    if bs_key in config and mean_kv_seqlen_key in config[bs_key]:
        return config[bs_key][mean_kv_seqlen_key]
    raise ValueError(
        f"No matching configuration found for bs key {bs_key} and mean kv seq "
        f"len key {mean_kv_seqlen_key} when initializing the decode attention "
        f"db")