import functools import json import torch import os from enum import Enum from typing import Any, Dict, Optional, Tuple import bisect from vllm.logger import init_logger logger = init_logger(__name__) class KERNLE_KINDS(Enum): v1_2stages = 0 v1_2stages_tc = 1 v2 = 2 v2_tc = 3 TOTAL_KIND = 4 class BestConfig(): def __init__(self): self.batch_size = 0 self.seq_len = 0 self.kernel_kind = KERNLE_KINDS.TOTAL_KIND self.BLOCK_N = 0 self.BLOCK_DIM = 0 # self.BLOCK_SEQ = 0 # self.SPLIT_K = 0 self.num_stages = 0 self.num_warps = 0 self.NUM_KV_SPLITS = 0 self.BLOCK_N_2 = 0 self.num_stages_2 = 0 self.num_warps_2 = 0 self.best_us = 0 self.decode_fwd_stage1 = None self.decode_fwd_stage2 = None def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str: if cache_dtype == "default": return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json" device_name = torch.cuda.get_device_name().replace(" ", "_") if "K100_AI" in device_name: return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json" elif "BW" in device_name: return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json" else: raise ValueError(f"Unsurpport device name: {device_name}") def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]: # First look up if an optimized configuration is available in the configs # directory json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: # logger.info("Using decode attention configuration from %s for attention layer.", config_file_path) # If a configuration has been found, return it return json.load(f) else: logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path) json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default") config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path) # If a configuration has been found, return it return json.load(f) else: raise ValueError("Please surpport default config can match 16 1 576 512") # If no optimized configuration is available, we will use the default # configuration return None def get_config_map(attention_configs): ret_map = {} for bs in attention_configs.keys(): int_bs = int(bs) seq_map = {} seq_configs = attention_configs[bs] ret_map[int_bs] = seq_map for seq_len in seq_configs.keys(): int_seq_len = int(seq_len) kind_config = seq_configs[seq_len] configs = BestConfig() # configs.batch_size = int_bs # configs.seq_len = int_seq_len configs.best_us = kind_config['best_us'] seq_map[int_seq_len] = configs if kind_config['kernel_kind'] == 'v1_2stages': best_config = kind_config['best_config'] stage1 = best_config['stage1'] stage2 = best_config['stage2'] configs.kernel_kind = KERNLE_KINDS.v1_2stages # configs.SPLIT_K = stage1['SPLIT_K'] configs.BLOCK_N = stage1['BLOCK_N'] configs.num_stages = stage1['num_stages'] configs.num_warps = stage1['num_warps'] configs.BLOCK_N_2 = stage2['BLOCK_N'] configs.num_stages_2 = stage2['num_stages'] configs.num_warps_2 = stage2['num_warps'] elif kind_config['kernel_kind'] == 'v1_2stages_tc': best_config = kind_config['best_config'] stage1 = best_config['stage1'] stage2 = best_config['stage2'] configs.kernel_kind = KERNLE_KINDS.v1_2stages_tc # configs.SPLIT_K = stage1['SPLIT_K'] configs.BLOCK_N = stage1['BLOCK_N'] configs.num_stages = stage1['num_stages'] configs.num_warps = stage1['num_warps'] configs.BLOCK_N_2 = stage2['BLOCK_N'] configs.num_stages_2 = stage2['num_stages'] configs.num_warps_2 = stage2['num_warps'] elif kind_config['kernel_kind'] == 'v2': best_config = kind_config['best_config'] stage1 = best_config['stage1'] stage2 = best_config['stage2'] configs.kernel_kind = KERNLE_KINDS.v2 # if 'BLOCK_SEQ' in stage1: # configs.BLOCK_SEQ = stage1['BLOCK_SEQ'] # else: # configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS'] configs.BLOCK_N = stage1['BLOCK_N'] configs.num_stages = stage1['num_stages'] configs.num_warps = stage1['num_warps'] configs.num_stages_2 = stage2['num_stages'] configs.num_warps_2 = stage2['num_warps'] elif kind_config['kernel_kind'] == 'v2_tc': best_config = kind_config['best_config'] stage1 = best_config['stage1'] stage2 = best_config['stage2'] configs.kernel_kind = KERNLE_KINDS.v2_tc # if 'BLOCK_SEQ' in stage1: # configs.BLOCK_SEQ = stage1['BLOCK_SEQ'] # else: # configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS'] configs.BLOCK_N = stage1['BLOCK_N'] configs.BLOCK_DIM = stage1['BLOCK_DIM'] configs.num_stages = stage1['num_stages'] configs.num_warps = stage1['num_warps'] configs.num_stages_2 = stage2['num_stages'] configs.num_warps_2 = stage2['num_warps'] return ret_map @functools.lru_cache def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]: attention_configs = get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype) return get_config_map(attention_configs) def get_closest_key(dic_keys, target_key): keys = list(dic_keys) idx = bisect.bisect_left(keys, target_key) if idx == 0: return keys[0] if idx == len(keys): return keys[-1] left_key = keys[idx - 1] right_key = keys[idx] if target_key - left_key <= right_key - target_key: return left_key else: return right_key def get_nearest_config(bs_key, mean_kv_seqlen_key, config): closest_bs_key = get_closest_key(config.keys(), bs_key) closest_mean_kv_seqlen_key = get_closest_key(config[closest_bs_key].keys(), mean_kv_seqlen_key) return config[closest_bs_key][closest_mean_kv_seqlen_key] def get_config(bs_key, mean_kv_seqlen_key, config): if bs_key in config and mean_kv_seqlen_key in config[bs_key]: return config[bs_key][mean_kv_seqlen_key] else: raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")