diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index fc4c53d0..929fdc3e 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer_vllm;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom;moe_gating_top_k;add_rms_norm_bias;apply_top_k_top_p_custom;transpose_kv_cache_by_block;" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -68,7 +68,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then CUSTOM_OPS_ARRAY=( "grouped_matmul_swiglu_quant_weight_nz_tensor_list" - "lightning_indexer" + "lightning_indexer_vllm" "sparse_flash_attention" "dispatch_ffn_combine" "dispatch_ffn_combine_bf16" diff --git a/csrc/lightning_indexer/op_host/CMakeLists.txt b/csrc/lightning_indexer_vllm/op_host/CMakeLists.txt similarity index 81% rename from csrc/lightning_indexer/op_host/CMakeLists.txt rename to csrc/lightning_indexer_vllm/op_host/CMakeLists.txt index 7922ba8e..29671371 100644 --- a/csrc/lightning_indexer/op_host/CMakeLists.txt +++ b/csrc/lightning_indexer_vllm/op_host/CMakeLists.txt @@ -8,7 +8,7 @@ # ====================================================================================================================== add_ops_compile_options( - OP_NAME LightningIndexer + OP_NAME LightningIndexerVllm OPTIONS --cce-auto-sync=off -Wno-deprecated-declarations -Werror @@ -16,19 +16,19 @@ add_ops_compile_options( --op_relocatable_kernel_binary=true ) -set(lightning_indexer_depends transformer/attention/lightning_indexer 
PARENT_SCOPE) +set(lightning_indexer_vllm_depends transformer/attention/lightning_indexer_vllm PARENT_SCOPE) target_sources(op_host_aclnn PRIVATE - lightning_indexer_def.cpp + lightning_indexer_vllm_def.cpp ) target_sources(optiling PRIVATE - lightning_indexer_tiling.cpp + lightning_indexer_vllm_tiling.cpp ) if (NOT BUILD_OPEN_PROJECT) target_sources(opmaster_ct PRIVATE - lightning_indexer_tiling.cpp + lightning_indexer_vllm_tiling.cpp ) endif () @@ -37,6 +37,6 @@ target_include_directories(optiling PRIVATE ) target_sources(opsproto PRIVATE - lightning_indexer_proto.cpp + lightning_indexer_vllm_proto.cpp ) diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_def.cpp similarity index 95% rename from csrc/lightning_indexer/op_host/lightning_indexer_def.cpp rename to csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_def.cpp index 95f97a34..3df7dd64 100644 --- a/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp +++ b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_def.cpp @@ -16,9 +16,9 @@ #include "register/op_def_registry.h" namespace ops { -class LightningIndexer : public OpDef { +class LightningIndexerVllm : public OpDef { public: - explicit LightningIndexer(const char *name) : OpDef(name) + explicit LightningIndexerVllm(const char *name) : OpDef(name) { this->Input("query") .ParamType(REQUIRED) @@ -68,5 +68,5 @@ public: this->AICore().AddConfig("ascend910_93", aicore_config); } }; -OP_ADD(LightningIndexer); +OP_ADD(LightningIndexerVllm); } // namespace ops \ No newline at end of file diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_proto.cpp similarity index 99% rename from csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp rename to csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_proto.cpp index cc1a793e..8761d9cb 100644 --- 
a/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp +++ b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_proto.cpp @@ -90,7 +90,7 @@ static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext return GRAPH_SUCCESS; } -IMPL_OP_INFERSHAPE(LightningIndexer) +IMPL_OP_INFERSHAPE(LightningIndexerVllm) .InferShape(InferShapeLightningIndexer) .InferDataType(InferDataTypeLightningIndexer); } // namespace ops diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_tiling.cpp similarity index 99% rename from csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp rename to csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_tiling.cpp index ae49996b..1355ff5b 100644 --- a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp +++ b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_tiling.cpp @@ -13,7 +13,7 @@ * \brief */ -#include "lightning_indexer_tiling.h" +#include "lightning_indexer_vllm_tiling.h" #include "../op_kernel/lightning_indexer_template_tiling_key.h" using namespace ge; @@ -687,7 +687,7 @@ ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context) return liTiling.DoTiling(&liInfo); } -IMPL_OP_OPTILING(LightningIndexer) +IMPL_OP_OPTILING(LightningIndexerVllm) .Tiling(TilingForLightningIndexer) .TilingParse(TilingPrepareForLightningIndexer); diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_tiling.h similarity index 98% rename from csrc/lightning_indexer/op_host/lightning_indexer_tiling.h rename to csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_tiling.h index fb7ce43d..bf5156a9 100644 --- a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h +++ b/csrc/lightning_indexer_vllm/op_host/lightning_indexer_vllm_tiling.h @@ -80,7 +80,7 @@ TILING_DATA_FIELD_DEF(uint32_t, blockSize) TILING_DATA_FIELD_DEF(uint32_t, 
maxBlockNumPerBatch) TILING_DATA_FIELD_DEF(uint32_t, sparseMode) END_TILING_DATA_DEF -REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData) +REGISTER_TILING_DATA_CLASS(LightningIndexerVllm, LITilingData) struct LICompileInfo {}; @@ -212,4 +212,4 @@ private: }; } // namespace optiling -#endif // LIGHTNING_INDEXER_TILING_H_ \ No newline at end of file +#endif // LIGHTNING_INDEXER_TILING_H_ diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_common.h similarity index 100% rename from csrc/lightning_indexer/op_kernel/lightning_indexer_common.h rename to csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_common.h diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_kernel.h similarity index 100% rename from csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h rename to csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_kernel.h diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_service_cube.h similarity index 100% rename from csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h rename to csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_service_cube.h diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_service_vector.h similarity index 100% rename from csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h rename to csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_service_vector.h diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_template_tiling_key.h similarity index 98% rename from csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h rename to 
csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_template_tiling_key.h index a4ce580a..b6f0a484 100644 --- a/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h +++ b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_template_tiling_key.h @@ -28,7 +28,7 @@ #define ASCENDC_TPL_4_BW 4 -ASCENDC_TPL_ARGS_DECL(LightningIndexer, +ASCENDC_TPL_ARGS_DECL(LightningIndexerVllm, ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16), ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16), ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1), diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_vector.h similarity index 100% rename from csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h rename to csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_vector.h diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_vllm.cpp similarity index 96% rename from csrc/lightning_indexer/op_kernel/lightning_indexer.cpp rename to csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_vllm.cpp index fefa72e6..8f2981a2 100644 --- a/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp +++ b/csrc/lightning_indexer_vllm/op_kernel/lightning_indexer_vllm.cpp @@ -35,7 +35,7 @@ using namespace LIKernel; template -__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, +__global__ __aicore__ void lightning_indexer_vllm(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 820fc5f2..c43e1a9c 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ 
-739,7 +739,7 @@ at::Tensor npu_lightning_indexer( char *query_layout_ptr = const_cast(query_layout_str.c_str()); char *key_layout_ptr = const_cast(key_layout_str.c_str()); EXEC_NPU_CMD( - aclnnLightningIndexer, + aclnnLightningIndexerVllm, query, key, weights, diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 26984947..c3f30e2c 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -431,6 +431,11 @@ class AscendSFAImpl(MLAAttentionImpl): self.weights_proj = self.indexer.weights_proj self.k_norm = self.indexer.k_norm self.cp_size = 1 + self.is_rope_neox_style = True + self.use_torch_npu_lightning_indexer = False + if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]: + self.is_rope_neox_style = False + self.use_torch_npu_lightning_indexer = True self.enable_dsa_cp = enable_dsa_cp() self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard() @@ -973,7 +978,9 @@ class AscendSFAImpl(MLAAttentionImpl): cos = cos.view(-1, self.qk_rope_head_dim) sin = sin.view(-1, self.qk_rope_head_dim) - q, k = rope_forward_triton(q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=True) + q, k = rope_forward_triton( + q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + ) else: k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) @@ -1036,18 +1043,35 @@ class AscendSFAImpl(MLAAttentionImpl): key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids) block_table = attn_metadata.sfa_cp_metadata.block_table_cp - topk_indices = torch.ops._C_ascend.npu_lightning_indexer( - query=q, - key=key, - weights=weights, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, - layout_query="TND", - layout_key="PA_BSND", - sparse_count=2048, - sparse_mode=3, - ) + # DSV3.2 currently has graph compilation issues when using 
torch_npu.npu_lightning_indexer. + # So two branches are maintained temporarily. + # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. + if self.use_torch_npu_lightning_indexer: + topk_indices, _ = torch_npu.npu_lightning_indexer( + query=q, + key=key, + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) + else: + topk_indices = torch.ops._C_ascend.npu_lightning_indexer( + query=q, + key=key, + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) return topk_indices def _init_o_proj_tp_full_params(self): diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 5034604e..307de36c 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -96,6 +96,11 @@ packed_modules_model_mapping: dict[str, dict[str, list[str]]] = { "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, + "glm_moe_dsa": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], + }, # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; # NOTE 2.The description file generated by the current msmodelslim tool does not have # MTP layer info. Please manually add it and set the value to FLOAT. 
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 85930c0b..873cb543 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -36,7 +36,14 @@ class MtpProposer(EagleProposer): dummy_compute_logits=lambda hidden_states: None, is_profile=False, ) -> None: - if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch: + # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. + # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. + # TODO: this conditional check should be removed after bug fixing. + if ( + self.pcp_size * self.dcp_size == 1 + and not self.speculative_config.disable_padded_drafter_batch + and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): super().dummy_run( num_tokens, with_prefill, @@ -166,7 +173,14 @@ class MtpProposer(EagleProposer): scheduler_output: SchedulerOutput = None, num_scheduled_tokens: int = 0, ) -> torch.Tensor: - if self.pcp_size * self.dcp_size == 1 and not self.speculative_config.disable_padded_drafter_batch: + # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. + # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. + # TODO: this conditional check should be removed after bug fixing. + if ( + self.pcp_size * self.dcp_size == 1 + and not self.speculative_config.disable_padded_drafter_batch + and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): draft_token_ids = super()._propose( target_token_ids, target_positions,