import os

# Support Q/KV generation with TP (tensor parallelism)
USE_PARALLEL_Q_KV_GEN = True

# Support merging Q/KV generation, and merging the Q and QR weights
USE_MERGE_Q_KV_GEN_AND_Q_QR = True

# Use FP8 weights for W_Q, W_QR, W_UV and W_UK
W_Q_W_QR_WUV_WUK_USE_FP8 = True

# Fused prefill
USE_FUSED_PREFILL = True

# Fused prefill, stage 1
USE_FUSED_PREFILL_STAGE1 = True

# Sequence lengths of all requests, refreshed via update_seqence_length()
DO_SEQ_LENS = 0


def update_seqence_length(seq_num):
    global DO_SEQ_LENS
    DO_SEQ_LENS = seq_num
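

# Usage sketch (hypothetical caller, not part of this module): the serving loop is
# expected to refresh the global before each step, e.g.
#   import config  # assumed module name
#   config.update_seqence_length(batch_seq_lens)  # whatever per-request lengths the runner tracks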

# DS3 sampler and DS3 sampler op (overridable via the environment)
USE_DS3_SAMPLER = int(os.getenv("USE_DS3_SAMPLER", 1))
USE_DS3_SAMPLER_OP = int(os.getenv("USE_DS3_SAMPLER_OP", 1))

# Cut prefill sequence length (default: -1)
CUT_PREFILL_SEQ_LEN = int(os.getenv("CUT_PREFILL_SEQ_LEN", -1))

# Maximum prefill sequence length for the LLM
LLM_MAX_PREFILL_SEQ_LEN = int(os.getenv("LLM_MAX_PREFILL_SEQ_LEN", 56 * 1024))
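
# Example override from the shell before launch (values are illustrative only):
#   export CUT_PREFILL_SEQ_LEN=4096
#   export LLM_MAX_PREFILL_SEQ_LEN=$((32 * 1024))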

# All-fused decode; the default is a CPU loop
USE_DECODER_LAYER_FUSE_MODE = int(os.getenv("USE_DECODER_LAYER_FUSE_MODE", 1))

# Fuse all decoder layers and use the cmcu loop
FUSE_ALL_DECODER_LAYERS = int(os.getenv("FUSE_ALL_DECODER_LAYERS", 1))

# Whether to use flash attention (default: 1)
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))

# Transpose GPTQ weights from KN to NK layout
TRANSPOSE_GPTQ_WEIGHT = True
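
# Sketch of the KN -> NK transpose (assumes a torch-style 2-D qweight tensor;
# the helper name is illustrative, not part of this module):
#   def to_nk(qweight):
#       return qweight.transpose(0, 1).contiguous()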

# Qwen fused attention
USE_FUSED_QWEN_ATTENTION = int(os.getenv("USE_FUSED_QWEN_ATTENTION", 1))

# Support MTP eh_proj with TP (tensor parallelism)
USE_PARALLEL_MTP_EH_PROJ = int(os.getenv("USE_PARALLEL_MTP_EH_PROJ", 1))

# KV cache block group size
BLOCK_GROUP_SIZE = int(os.getenv("BLOCK_GROUP_SIZE", 8192))

# BERT fused attention
USE_FUSED_BERT_ATTENTION = int(os.getenv("USE_FUSED_BERT_ATTENTION", 1))

# Fused vision MLP
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))
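

# Minimal self-check sketch (illustrative only, not required by the rest of the code):
# running this module directly prints the configuration values resolved above.
if __name__ == "__main__":
    for _name in sorted(globals()):
        if _name.isupper():
            print(f"{_name} = {globals()[_name]}")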