import os

# Support Q, KV generation with tensor parallelism (TP).
USE_PARALLEL_Q_KV_GEN = True

# Support merging Q, KV generation and merging the Q, QR weights.
USE_MERGE_Q_KV_GEN_AND_Q_QR = True

# Support FP8 weights for W_Q, W_QR, W_UV, and W_UK.
W_Q_W_QR_WUV_WUK_USE_FP8 = True

# Fused prefill.
USE_FUSED_PREFILL = True

# Fused prefill, stage 1.
USE_FUSED_PREFILL_STAGE1 = True

# Sequence lengths across all requests.
DO_SEQ_LENS = 0


def update_seqence_length(seq_num):
    global DO_SEQ_LENS
    DO_SEQ_LENS = seq_num


USE_DS3_SAMPLER = int(os.getenv("USE_DS3_SAMPLER", 1))
USE_DS3_SAMPLER_OP = int(os.getenv("USE_DS3_SAMPLER_OP", 1))

# Cut the prefill sequence length (-1 means no cut).
CUT_PREFILL_SEQ_LEN = int(os.getenv("CUT_PREFILL_SEQ_LEN", -1))

# Maximum prefill sequence length for the LLM.
LLM_MAX_PREFILL_SEQ_LEN = int(os.getenv("LLM_MAX_PREFILL_SEQ_LEN", 56 * 1024))

# Fully fused decode; the default is the CPU loop.
USE_DECODER_LAYER_FUSE_MODE = int(os.getenv("USE_DECODER_LAYER_FUSE_MODE", 1))

# Fuse all decoder layers, using the cmcu loop.
FUSE_ALL_DECODER_LAYERS = int(os.getenv("FUSE_ALL_DECODER_LAYERS", 1))

# Whether to use flash attention (default: 1).
USE_FLASH_ATTENTION = int(os.getenv("USE_FLASH_ATTENTION", 1))

# Transpose GPTQ weights from KN to NK layout.
TRANSPOSE_GPTQ_WEIGHT = True

# Qwen fused attention.
USE_FUSED_QWEN_ATTENTION = int(os.getenv("USE_FUSED_QWEN_ATTENTION", 1))

# Support MTP eh_proj with tensor parallelism (TP).
USE_PARALLEL_MTP_EH_PROJ = int(os.getenv("USE_PARALLEL_MTP_EH_PROJ", 1))

# kv_cache group size.
BLOCK_GROUP_SIZE = int(os.getenv("BLOCK_GROUP_SIZE", 8192))

# BERT fused attention.
USE_FUSED_BERT_ATTENTION = int(os.getenv("USE_FUSED_BERT_ATTENTION", 1))

# Fused MLP for the vision model.
USE_FUSED_MLP_VISION = int(os.getenv("USE_FUSED_MLP_VISION", 1))
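
# Note on the flag pattern: every env-driven flag above is resolved once, at
# import time, via int(os.getenv(NAME, default)). Overrides therefore have to
# be exported before this module is first imported. Below is a minimal smoke
# test sketch (an illustrative addition, not part of the original flag set):
# running this module directly prints every resolved upper-case flag so the
# effect of environment overrides can be verified, e.g.
#   USE_FLASH_ATTENTION=0 CUT_PREFILL_SEQ_LEN=4096 python <this_module>.py
if __name__ == "__main__":
    for _name in sorted(globals()):
        if _name.isupper():  # flags in this module are all upper-case globals
            print(f"{_name} = {globals()[_name]}")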