[Model] GLM5 adaptation (#6642)
### What this PR does / why we need it?
GLM5 adaptation
1. use torch_npu.npu_lightning_indexer for GLM5
2. forbid eagle proposer when fullgraph mode is enabled because of bugs
3. add quantization config for GLM5
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested by CI.
- vLLM main:
978a37c823
---------
Signed-off-by: yydyzr <liuyuncong1@huawei.com>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
|
|||||||
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
|
||||||
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH}
|
||||||
|
|
||||||
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
|
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;"
|
||||||
SOC_ARG="ascend910b"
|
SOC_ARG="ascend910b"
|
||||||
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||||
# ASCEND910C (A3) series
|
# ASCEND910C (A3) series
|
||||||
@@ -68,7 +68,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
|||||||
|
|
||||||
CUSTOM_OPS_ARRAY=(
|
CUSTOM_OPS_ARRAY=(
|
||||||
"grouped_matmul_swiglu_quant_weight_nz_tensor_list"
|
"grouped_matmul_swiglu_quant_weight_nz_tensor_list"
|
||||||
"lightning_indexer"
|
"lightning_indexer_vllm"
|
||||||
"sparse_flash_attention"
|
"sparse_flash_attention"
|
||||||
"dispatch_ffn_combine"
|
"dispatch_ffn_combine"
|
||||||
"dispatch_ffn_combine_bf16"
|
"dispatch_ffn_combine_bf16"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|
||||||
add_ops_compile_options(
|
add_ops_compile_options(
|
||||||
OP_NAME LightningIndexer
|
OP_NAME LightningIndexerVllm
|
||||||
OPTIONS --cce-auto-sync=off
|
OPTIONS --cce-auto-sync=off
|
||||||
-Wno-deprecated-declarations
|
-Wno-deprecated-declarations
|
||||||
-Werror
|
-Werror
|
||||||
@@ -16,19 +16,19 @@ add_ops_compile_options(
|
|||||||
--op_relocatable_kernel_binary=true
|
--op_relocatable_kernel_binary=true
|
||||||
)
|
)
|
||||||
|
|
||||||
set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE)
|
set(lightning_indexer_vllm_depends transformer/attention/lightning_indexer_vllm PARENT_SCOPE)
|
||||||
|
|
||||||
target_sources(op_host_aclnn PRIVATE
|
target_sources(op_host_aclnn PRIVATE
|
||||||
lightning_indexer_def.cpp
|
lightning_indexer_vllm_def.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_sources(optiling PRIVATE
|
target_sources(optiling PRIVATE
|
||||||
lightning_indexer_tiling.cpp
|
lightning_indexer_vllm_tiling.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
if (NOT BUILD_OPEN_PROJECT)
|
if (NOT BUILD_OPEN_PROJECT)
|
||||||
target_sources(opmaster_ct PRIVATE
|
target_sources(opmaster_ct PRIVATE
|
||||||
lightning_indexer_tiling.cpp
|
lightning_indexer_vllm_tiling.cpp
|
||||||
)
|
)
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
@@ -37,6 +37,6 @@ target_include_directories(optiling PRIVATE
|
|||||||
)
|
)
|
||||||
|
|
||||||
target_sources(opsproto PRIVATE
|
target_sources(opsproto PRIVATE
|
||||||
lightning_indexer_proto.cpp
|
lightning_indexer_vllm_proto.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,9 +16,9 @@
|
|||||||
#include "register/op_def_registry.h"
|
#include "register/op_def_registry.h"
|
||||||
|
|
||||||
namespace ops {
|
namespace ops {
|
||||||
class LightningIndexer : public OpDef {
|
class LightningIndexerVllm : public OpDef {
|
||||||
public:
|
public:
|
||||||
explicit LightningIndexer(const char *name) : OpDef(name)
|
explicit LightningIndexerVllm(const char *name) : OpDef(name)
|
||||||
{
|
{
|
||||||
this->Input("query")
|
this->Input("query")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
@@ -68,5 +68,5 @@ public:
|
|||||||
this->AICore().AddConfig("ascend910_93", aicore_config);
|
this->AICore().AddConfig("ascend910_93", aicore_config);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
OP_ADD(LightningIndexer);
|
OP_ADD(LightningIndexerVllm);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
@@ -90,7 +90,7 @@ static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext
|
|||||||
return GRAPH_SUCCESS;
|
return GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
IMPL_OP_INFERSHAPE(LightningIndexer)
|
IMPL_OP_INFERSHAPE(LightningIndexerVllm)
|
||||||
.InferShape(InferShapeLightningIndexer)
|
.InferShape(InferShapeLightningIndexer)
|
||||||
.InferDataType(InferDataTypeLightningIndexer);
|
.InferDataType(InferDataTypeLightningIndexer);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
* \brief
|
* \brief
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "lightning_indexer_tiling.h"
|
#include "lightning_indexer_vllm_tiling.h"
|
||||||
#include "../op_kernel/lightning_indexer_template_tiling_key.h"
|
#include "../op_kernel/lightning_indexer_template_tiling_key.h"
|
||||||
|
|
||||||
using namespace ge;
|
using namespace ge;
|
||||||
@@ -687,7 +687,7 @@ ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context)
|
|||||||
return liTiling.DoTiling(&liInfo);
|
return liTiling.DoTiling(&liInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
IMPL_OP_OPTILING(LightningIndexer)
|
IMPL_OP_OPTILING(LightningIndexerVllm)
|
||||||
.Tiling(TilingForLightningIndexer)
|
.Tiling(TilingForLightningIndexer)
|
||||||
.TilingParse<LICompileInfo>(TilingPrepareForLightningIndexer);
|
.TilingParse<LICompileInfo>(TilingPrepareForLightningIndexer);
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ TILING_DATA_FIELD_DEF(uint32_t, blockSize)
|
|||||||
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
|
TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch)
|
||||||
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
|
TILING_DATA_FIELD_DEF(uint32_t, sparseMode)
|
||||||
END_TILING_DATA_DEF
|
END_TILING_DATA_DEF
|
||||||
REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData)
|
REGISTER_TILING_DATA_CLASS(LightningIndexerVllm, LITilingData)
|
||||||
|
|
||||||
struct LICompileInfo {};
|
struct LICompileInfo {};
|
||||||
|
|
||||||
@@ -212,4 +212,4 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace optiling
|
} // namespace optiling
|
||||||
#endif // LIGHTNING_INDEXER_TILING_H_
|
#endif // LIGHTNING_INDEXER_TILING_H_
|
||||||
@@ -28,7 +28,7 @@
|
|||||||
|
|
||||||
#define ASCENDC_TPL_4_BW 4
|
#define ASCENDC_TPL_4_BW 4
|
||||||
|
|
||||||
ASCENDC_TPL_ARGS_DECL(LightningIndexer,
|
ASCENDC_TPL_ARGS_DECL(LightningIndexerVllm,
|
||||||
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
|
ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16),
|
||||||
ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
|
ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16),
|
||||||
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
|
ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1),
|
||||||
@@ -35,7 +35,7 @@ using namespace LIKernel;
|
|||||||
|
|
||||||
|
|
||||||
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int LAYOUT_T, int K_LAYOUT_T>
|
template <int DT_Q, int DT_K, int DT_OUT, int PAGE_ATTENTION, int LAYOUT_T, int K_LAYOUT_T>
|
||||||
__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
__global__ __aicore__ void lightning_indexer_vllm(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights,
|
||||||
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
|
__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths,
|
||||||
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
|
__gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices,
|
||||||
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
|
__gm__ uint8_t *workspace, __gm__ uint8_t *tiling)
|
||||||
@@ -739,7 +739,7 @@ at::Tensor npu_lightning_indexer(
|
|||||||
char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
|
char *query_layout_ptr = const_cast<char *>(query_layout_str.c_str());
|
||||||
char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
|
char *key_layout_ptr = const_cast<char *>(key_layout_str.c_str());
|
||||||
EXEC_NPU_CMD(
|
EXEC_NPU_CMD(
|
||||||
aclnnLightningIndexer,
|
aclnnLightningIndexerVllm,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
weights,
|
weights,
|
||||||
|
|||||||
@@ -431,6 +431,11 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
self.weights_proj = self.indexer.weights_proj
|
self.weights_proj = self.indexer.weights_proj
|
||||||
self.k_norm = self.indexer.k_norm
|
self.k_norm = self.indexer.k_norm
|
||||||
self.cp_size = 1
|
self.cp_size = 1
|
||||||
|
self.is_rope_neox_style = True
|
||||||
|
self.use_torch_npu_lightning_indexer = False
|
||||||
|
if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
|
||||||
|
self.is_rope_neox_style = False
|
||||||
|
self.use_torch_npu_lightning_indexer = True
|
||||||
|
|
||||||
self.enable_dsa_cp = enable_dsa_cp()
|
self.enable_dsa_cp = enable_dsa_cp()
|
||||||
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
|
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
|
||||||
@@ -973,7 +978,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
cos = cos.view(-1, self.qk_rope_head_dim)
|
cos = cos.view(-1, self.qk_rope_head_dim)
|
||||||
sin = sin.view(-1, self.qk_rope_head_dim)
|
sin = sin.view(-1, self.qk_rope_head_dim)
|
||||||
q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True)
|
q, k = rope_forward_triton(
|
||||||
|
q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)
|
k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
@@ -1036,18 +1043,35 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
|
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
|
||||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||||
|
|
||||||
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
|
# DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer.
|
||||||
query=q,
|
# So two branches are maintained temporarily.
|
||||||
key=key,
|
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
|
||||||
weights=weights,
|
if self.use_torch_npu_lightning_indexer:
|
||||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
topk_indices, _ = torch_npu.npu_lightning_indexer(
|
||||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
query=q,
|
||||||
block_table=block_table,
|
key=key,
|
||||||
layout_query="TND",
|
weights=weights,
|
||||||
layout_key="PA_BSND",
|
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||||
sparse_count=2048,
|
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||||
sparse_mode=3,
|
block_table=block_table,
|
||||||
)
|
layout_query="TND",
|
||||||
|
layout_key="PA_BSND",
|
||||||
|
sparse_count=2048,
|
||||||
|
sparse_mode=3,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
|
||||||
|
query=q,
|
||||||
|
key=key,
|
||||||
|
weights=weights,
|
||||||
|
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||||
|
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||||
|
block_table=block_table,
|
||||||
|
layout_query="TND",
|
||||||
|
layout_key="PA_BSND",
|
||||||
|
sparse_count=2048,
|
||||||
|
sparse_mode=3,
|
||||||
|
)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
def _init_o_proj_tp_full_params(self):
|
def _init_o_proj_tp_full_params(self):
|
||||||
|
|||||||
@@ -96,6 +96,11 @@ packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
|
|||||||
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
|
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
|
||||||
},
|
},
|
||||||
|
"glm_moe_dsa": {
|
||||||
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
|
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||||
|
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
|
||||||
|
},
|
||||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||||
|
|||||||
@@ -36,7 +36,14 @@ class MtpProposer(EagleProposer):
|
|||||||
dummy_compute_logits=lambda hidden_states: None,
|
dummy_compute_logits=lambda hidden_states: None,
|
||||||
is_profile=False,
|
is_profile=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
|
# Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
|
||||||
|
# Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
|
||||||
|
# TODO: this conditional check should be removed after bug fixing.
|
||||||
|
if (
|
||||||
|
self.pcp_size * self.dcp_size == 1
|
||||||
|
and not self.speculative_config.disable_padded_drafter_batch
|
||||||
|
and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
|
):
|
||||||
super().dummy_run(
|
super().dummy_run(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
with_prefill,
|
with_prefill,
|
||||||
@@ -166,7 +173,14 @@ class MtpProposer(EagleProposer):
|
|||||||
scheduler_output: SchedulerOutput = None,
|
scheduler_output: SchedulerOutput = None,
|
||||||
num_scheduled_tokens: int = 0,
|
num_scheduled_tokens: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch:
|
# Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer.
|
||||||
|
# Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph.
|
||||||
|
# TODO: this conditional check should be removed after bug fixing.
|
||||||
|
if (
|
||||||
|
self.pcp_size * self.dcp_size == 1
|
||||||
|
and not self.speculative_config.disable_padded_drafter_batch
|
||||||
|
and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
|
):
|
||||||
draft_token_ids = super()._propose(
|
draft_token_ids = super()._propose(
|
||||||
target_token_ids,
|
target_token_ids,
|
||||||
target_positions,
|
target_positions,
|
||||||
|
|||||||
Reference in New Issue
Block a user