[Model] GLM5 adaptation (#6642)
### What this PR does / why we need it?
GLM5 adaptation
1. use torch_npu.npu_lightning_indexer for GLM5
2. forbid eagle proposer when fullgraph mode is enabled because of bugs
3. add quantization config for GLM5
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
By CI.
- vLLM main:
978a37c823
---------
Signed-off-by: yydyzr <liuyuncong1@huawei.com>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
|
||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
|
||||
SOC_ARG="ascend910b"
|
||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
# ASCEND910C (A3) series
|
||||
@@ -68,7 +68,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
|
||||
CUSTOM_OPS_ARRAY=(
|
||||
"grouped_matmul_swiglu_quant_weight_nz_tensor_list"
|
||||
"lightning_indexer"
|
||||
"lightning_indexer_vllm"
|
||||
"sparse_flash_attention"
|
||||
"dispatch_ffn_combine"
|
||||
"dispatch_ffn_combine_bf16"
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# ======================================================================================================================
|
||||
|
||||
add_ops_compile_options(
|
||||
OP_NAME LightningIndexer
|
||||
OP_NAME LightningIndexerVllm
|
||||
OPTIONS --cce-auto-sync=off
|
||||
-Wno-deprecated-declarations
|
||||
-Werror
|
||||
@@ -16,19 +16,19 @@ add_ops_compile_options(
|
||||
--op_relocatable_kernel_binary=true
|
||||
)
|
||||
|
||||
set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE)
|
||||
set(lightning_indexer_vllm_depends transformer/attention/lightning_indexer_vllm PARENT_SCOPE)
|
||||
|
||||
target_sources(op_host_aclnn PRIVATE
|
||||
lightning_indexer_def.cpp
|
||||
lightning_indexer_vllm_def.cpp
|
||||
)
|
||||
|
||||
target_sources(optiling PRIVATE
|
||||
lightning_indexer_tiling.cpp
|
||||
lightning_indexer_vllm_tiling.cpp
|
||||
)
|
||||
|
||||
if (NOT BUILD_OPEN_PROJECT)
|
||||
target_sources(opmaster_ct PRIVATE
|
||||
lightning_indexer_tiling.cpp
|
||||
lightning_indexer_vllm_tiling.cpp
|
||||
)
|
||||
endif ()
|
||||
|
||||
@@ -37,6 +37,6 @@ target_include_directories(optiling PRIVATE
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
lightning_indexer_proto.cpp
|
||||
lightning_indexer_vllm_proto.cpp
|
||||
)
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
#include "register/op_def_registry.h"
|
||||
|
||||
namespace ops {
|
||||
class LightningIndexer : public OpDef {
|
||||
class LightningIndexerVllm : public OpDef {
|
||||
public:
|
||||
explicit LightningIndexer(const char *name) : OpDef(name)
|
||||
explicit LightningIndexerVllm(const char *name) : OpDef(name)
|
||||
{
|
||||
this->Input("query")
|
||||
.ParamType(REQUIRED)
|
||||
@@ -68,5 +68,5 @@ public:
|
||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||
}
|
||||
};
|
||||
OP_ADD(LightningIndexer);
|
||||
OP_ADD(LightningIndexerVllm);
|
||||
} // namespace ops
|
||||
@@ -90,7 +90,7 @@ static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPL_OP_INFERSHAPE(LightningIndexer)
|
||||
IMPL_OP_INFERSHAPE(LightningIndexerVllm)
|
||||
.InferShape(InferShapeLightningIndexer)
|
||||
.InferDataType(InferDataTypeLightningIndexer);
|
||||
} // namespace ops
|
||||
@@ -13,7 +13,7 @@
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#include "lightning_indexer_tiling.h"
|
||||
#include "lightning_indexer_vllm_tiling.h"
|
||||
#include "../op_kernel/lightning_indexer_template_tiling_key.h"
|
||||
|
||||
using namespace ge;
|
||||
@@ -687,7 +687,7 @@ ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context)
|
||||
return liTiling.DoTiling(&liInfo);
|
||||
}
|
||||
|
||||
IMPL_OP_OPTILING(LightningIndexer)
|
||||
IMPL_OP_OPTILING(LightningIndexerVllm)
|
||||
.Tiling(TilingForLightningIndexer)
|
||||
.TilingParse<LICompileInfo>(TilingPrepareForLightningIndexer);
|
||||
|
||||
@@ -80,7 +80,7 @@ TILING_DATA_FIELD_DEF(uint32_t, blockSize)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
|
||||
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
|
||||
END_TILING_DATA_DEF
|
||||
REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData)
|
||||
REGISTER_TILING_DATA_CLASS(LightningIndexerVllm, LITilingData)
|
||||
|
||||
struct LICompileInfo {};
|
||||
|
||||
@@ -212,4 +212,4 @@ private:
|
||||
};
|
||||
|
||||
} // namespace optiling
|
||||
#endif // LIGHTNING_INDEXER_TILING_H_
|
||||
#endif // LIGHTNING_INDEXER_TILING_H_
|
||||
@@ -28,7 +28,7 @@
|
||||
|
||||
#define ASCENDC_TPL_4_BW 4
|
||||
|
||||
ASCENDC_TPL_ARGS_DECL(LightningIndexer,
|
||||
ASCENDC_TPL_ARGS_DECL(LightningIndexerVllm,
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
|
||||
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
|
||||
@@ -35,7 +35,7 @@ using namespace LIKernel;
|
||||
|
||||
|
||||
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int LAYOUT_T, int K_LAYOUT_T>
|
||||
__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
||||
__global__ __aicore__ void lightning_indexer_vllm(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
||||
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
|
||||
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
|
||||
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
|
||||
@@ -739,7 +739,7 @@ at::Tensor npu_lightning_indexer(
|
||||
char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
|
||||
char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
|
||||
EXEC_NPU_CMD(
|
||||
aclnnLightningIndexer,
|
||||
aclnnLightningIndexerVllm,
|
||||
query,
|
||||
key,
|
||||
weights,
|
||||
|
||||
Reference in New Issue
Block a user