From 1de805ce0ac2d25d6b77755692f59f0dee4254ae Mon Sep 17 00:00:00 2001
From: jiaojiao <56385650+wenba0@users.noreply.github.com>
Date: Tue, 24 Mar 2026 00:07:12 +0800
Subject: [PATCH] [Ops][Misc] Refactor and optimize CausalConv1d for Ascend
 (#7495)

### What this PR does / why we need it?
During the prefill phase of Qwen3-Next and Qwen3.5, the `torch.ops._C_ascend.causal_conv1d_fn` operator is a significant performance bottleneck. To address this, we re-implemented and optimized it as `torch.ops._C_ascend.npu_causal_conv1d_custom`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
1. Accuracy test:
```
[2026-03-20 16:44:22,961] [ais_bench] [INFO] Start launch task state board ...
+-----------------------------+-----------+------------+-------------+----------+-------------------------------------------+---------------------+
| Task Name                   | Process   | Progress   | Time Cost   | Status   | Log Path                                  | Extend Parameters   |
+=============================+===========+============+=============+==========+===========================================+=====================+
| vllm-api-general-chat/gsm8k | 2918978   | NA         | 0:00:01     | finish   | logs/eval/vllm-api-general-chat/gsm8k.out | None                |
+-----------------------------+-----------+------------+-------------+----------+-------------------------------------------+---------------------+
[2026-03-20 16:44:34,284] [ais_bench] [INFO] Evaluation tasks completed.
[2026-03-20 16:44:34,287] [ais_bench] [INFO] Summarizing evaluation results...
dataset    version    metric    mode    vllm-api-general-chat
---------  ---------  --------  ------  -----------------------
gsm8k      271d0b     accuracy  gen     96.21
```
2. Updated unit test:
`pytest -sv /home/c30006096/vllm-ascend/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py::test_ascend_causal_conv1d`

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c

Signed-off-by: wenba0 <3054239545@qq.com>
Signed-off-by: jiaojiao <56385650+wenba0@users.noreply.github.com>
---
 csrc/causal_conv1d/op_host/CMakeLists.txt     |   2 +-
 .../op_host/causal_conv1d_def.cpp             |  24 +-
 .../op_host/causal_conv1d_infershape.cpp      |  18 +-
 .../op_host/causal_conv1d_tiling.cpp          | 599 ++++++++++++------
 .../op_host/causal_conv1d_tiling.h            |  60 --
 csrc/causal_conv1d/op_host/math_util.h        |   2 +-
 .../causal_conv1d/op_kernel/causal_conv1d.cpp |  36 +-
 csrc/causal_conv1d/op_kernel/causal_conv1d.h  | 511 +++++++++------
 .../op_kernel/causal_conv1d_common.h          |   2 +-
 .../op_kernel/causal_conv1d_tiling_data.h     |  49 ++
 .../op_kernel/causal_conv1d_tiling_key.h      |   2 +-
 csrc/torch_binding.cpp                        |  74 +--
 csrc/torch_binding_meta.cpp                   |  24 +-
 .../triton/test_causal_conv1d.py              |  20 +-
 vllm_ascend/patch/worker/patch_qwen3_5.py     |  22 +-
 vllm_ascend/patch/worker/patch_qwen3_next.py  |  16 +-
 16 files changed, 907 insertions(+), 554 deletions(-)
 delete mode 100644 csrc/causal_conv1d/op_host/causal_conv1d_tiling.h
 create mode 100644 csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_data.h

diff --git a/csrc/causal_conv1d/op_host/CMakeLists.txt b/csrc/causal_conv1d/op_host/CMakeLists.txt
index 4644a8bd..41cf86fe 100644
--- a/csrc/causal_conv1d/op_host/CMakeLists.txt
+++ b/csrc/causal_conv1d/op_host/CMakeLists.txt
@@ -9,7 +9,7 @@
 add_ops_compile_options(
     OP_NAME CausalConv1d
-    OPTIONS --cce-auto-sync=off
+    OPTIONS --cce-auto-sync=on -Wno-deprecated-declarations -Werror
 )
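Note (illustrative, not part of the patch): the sketch below shows roughly how the new custom op could be invoked in the prefill path. The argument names, ordering, optional-argument handling, and the `torch_npu` `.npu()` idiom are inferred from the op definition and tiling checks in this patch; the actual Python binding in `csrc/torch_binding.cpp` may differ.

```python
import torch
import torch_npu  # noqa: F401  # assumed Ascend backend; provides .npu()

# Shapes chosen to satisfy the tiling checks introduced by this patch:
# dim % 16 == 0, width in [2, 4], state_len >= width - 1.
cu_seqlen, dim, width = 32, 2048, 4
batch, num_cache_lines, state_len = 2, 8, 3

x = torch.randn(cu_seqlen, dim, dtype=torch.float16).npu()   # 2D varlen layout
weight = torch.randn(width, dim, dtype=torch.float16).npu()
bias = torch.randn(dim, dtype=torch.float16).npu()           # optional
conv_states = torch.zeros(num_cache_lines, state_len, dim,
                          dtype=torch.float16).npu()
query_start_loc = torch.tensor([0, 16, 32], dtype=torch.int64).npu()  # (batch + 1,), starts at 0
cache_indices = torch.tensor([0, 1], dtype=torch.int64).npu()         # (batch,)
initial_state_mode = torch.zeros(batch, dtype=torch.int64).npu()      # 0/1 per sequence

# Hypothetical call; signature inferred from causal_conv1d_def.cpp below.
y = torch.ops._C_ascend.npu_causal_conv1d_custom(
    x, weight, bias, conv_states,
    query_start_loc, cache_indices, initial_state_mode,
    None,  # num_accepted_tokens: only used in decode/update (run_mode=1)
    1,     # activation_mode: 0 = none, 1 = SiLU
    -1,    # pad_slot_id
    0,     # run_mode: 0 = prefill, 1 = decode/update
)
```

diff --git a/csrc/causal_conv1d/op_host/causal_conv1d_def.cpp b/csrc/causal_conv1d/op_host/causal_conv1d_def.cpp
index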
02a1c752..2e7049f4 100644 --- a/csrc/causal_conv1d/op_host/causal_conv1d_def.cpp +++ b/csrc/causal_conv1d/op_host/causal_conv1d_def.cpp @@ -42,19 +42,28 @@ public: .FormatList({ge::FORMAT_ND}) .AutoContiguous(); this->Input("queryStartLoc") - .ParamType(REQUIRED) - .DataTypeList({ge::DT_INT32}) + .ParamType(OPTIONAL) + .DataTypeList({ge::DT_INT64}) .FormatList({ge::FORMAT_ND}) + .ValueDepend(OPTIONAL) .AutoContiguous(); this->Input("cacheIndices") - .ParamType(REQUIRED) - .DataTypeList({ge::DT_INT32}) + .ParamType(OPTIONAL) + .DataTypeList({ge::DT_INT64}) .FormatList({ge::FORMAT_ND}) + .ValueDepend(OPTIONAL) .AutoContiguous(); - this->Input("hasInitialState") - .ParamType(REQUIRED) - .DataTypeList({ge::DT_BOOL}) + this->Input("initialStateMode") + .ParamType(OPTIONAL) + .DataTypeList({ge::DT_INT64}) .FormatList({ge::FORMAT_ND}) + .ValueDepend(OPTIONAL) + .AutoContiguous(); + this->Input("numAcceptedTokens") + .ParamType(OPTIONAL) + .DataTypeList({ge::DT_INT64}) + .FormatList({ge::FORMAT_ND}) + .ValueDepend(OPTIONAL) .AutoContiguous(); this->Output("y") @@ -65,6 +74,7 @@ public: this->Attr("activationMode").AttrType(OPTIONAL).Int(0); this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1); + this->Attr("runMode").AttrType(OPTIONAL).Int(0); OpAICoreConfig aicoreConfig; aicoreConfig.DynamicCompileStaticFlag(true) diff --git a/csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp b/csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp index 6c185ea0..b4cf380b 100644 --- a/csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp +++ b/csrc/causal_conv1d/op_host/causal_conv1d_infershape.cpp @@ -14,7 +14,7 @@ * \brief */ #include "register/op_impl_registry.h" -#include "error_log.h" +#include "log/log.h" using namespace ge; @@ -23,27 +23,19 @@ static constexpr int64_t IDX_0 = 0; static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context) { - // OPS_LOG_D(context->GetNodeName(), "Begin to do InferShapeCausalConv1d"); + OP_LOGD(context->GetNodeName(), "Begin to do InferShapeCausalConv1d"); - // get input shapes const gert::Shape* xShape = context->GetInputShape(IDX_0); OP_CHECK_NULL_WITH_CONTEXT(context, xShape); - // get output shapes gert::Shape* yShape = context->GetOutputShape(IDX_0); OP_CHECK_NULL_WITH_CONTEXT(context, yShape); - // 填充输出shape大小 - auto xShapeSize = xShape->GetDimNum(); - yShape->SetDimNum(xShapeSize); - for (size_t i = 0; i < xShapeSize; i++) { - int64_t dim = xShape->GetDim(i); - yShape->SetDim(i, dim); - } + *yShape = *xShape; - // OPS_LOG_D(context->GetNodeName(), "End to do InferShapeCausalConv1d"); + OP_LOGD(context->GetNodeName(), "End to do InferShapeCausalConv1d"); return GRAPH_SUCCESS; } IMPL_OP_INFERSHAPE(CausalConv1d).InferShape(InferShapeCausalConv1d); -} // namespace ops \ No newline at end of file +} // namespace ops diff --git a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp index fa8bd23f..0a4117bf 100644 --- a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp +++ b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp @@ -14,12 +14,12 @@ * \brief */ -// #include "error_log.h" -#include "log/ops_log.h" +//#include "log/log.h" +#include "error_log.h" #include "../tiling_base/tiling_templates_registry.h" #include "../tiling_base/tiling_util.h" #include "math_util.h" -#include "causal_conv1d_tiling.h" +#include "../op_kernel/causal_conv1d_tiling_data.h" #include "../op_kernel/causal_conv1d_tiling_key.h" #include @@ -35,12 +35,17 @@ constexpr uint32_t BIAS_INDEX = 2; constexpr uint32_t 
CONV_STATES_INDEX = 3; constexpr uint32_t QUERY_START_LOC_INDEX = 4; constexpr uint32_t CACHE_INDICES_INDEX = 5; -constexpr uint32_t HAS_INITIAL_STATE_INDEX = 6; +constexpr uint32_t INITIAL_STATE_MODE_INDEX = 6; +constexpr uint32_t NUM_ACCEPTED_TOKENS_INDEX = 7; constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0; constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1; +constexpr int32_t ATTR_RUN_MODE_INDEX = 2; - +struct CausalConv1dCompileInfo { + uint64_t ubSize = 0; + uint32_t coreNum = 0; +}; struct DimTileChoice { int64_t dimTileSize = 0; @@ -48,64 +53,81 @@ struct DimTileChoice { int64_t gridSize = 0; }; +static inline int64_t CeilDivInt64(int64_t x, int64_t y) +{ + return (x + y - 1) / y; +} + +static inline bool FitsInInt32(int64_t v) +{ + return v >= static_cast(std::numeric_limits::min()) && + v <= static_cast(std::numeric_limits::max()); +} + static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum) { + const int64_t candidates[] = {4096, 2048, 1024, 512, 384, 192}; - const int64_t candidates[] = {4096, 2048, 1024, 512,384}; - DimTileChoice bestOver; - int64_t bestOverGap = std::numeric_limits::max(); - DimTileChoice bestUnder; + auto ChooseOnce = [&](bool requireExactDiv) -> DimTileChoice { + DimTileChoice bestOver; + int64_t bestOverGap = std::numeric_limits::max(); + DimTileChoice bestUnder; - for (int64_t dimTileSize : candidates) { - if (dim % dimTileSize != 0) { - continue; - } - const int64_t blocksPerSeq = dim / dimTileSize; - const int64_t gridSize = batch * blocksPerSeq; - if (gridSize <= 0) { - continue; - } - - if (gridSize >= static_cast(coreNum)) { - const int64_t gap = gridSize - static_cast(coreNum); - if (gap < bestOverGap) { - bestOver.dimTileSize = dimTileSize; - bestOver.blocksPerSeq = blocksPerSeq; - bestOver.gridSize = gridSize; - bestOverGap = gap; + for (int64_t dimTileSize : candidates) { + if (dimTileSize <= 0) { + continue; } - } else if (gridSize > bestUnder.gridSize || - (gridSize == bestUnder.gridSize && dimTileSize < bestUnder.dimTileSize)) { + if (requireExactDiv && (dim % dimTileSize != 0)) { + continue; + } + const int64_t blocksPerSeq = requireExactDiv ? (dim / dimTileSize) : CeilDivInt64(dim, dimTileSize); + const int64_t gridSize = batch * blocksPerSeq; + if (gridSize <= 0) { + continue; + } + OP_LOGD(context, + "DimTile candidate[%s]: dimTileSize[%ld], blocksPerSeq[%ld], gridSize[%ld], coreNum[%u].", + requireExactDiv ? "exact" : "tail", + dimTileSize, blocksPerSeq, gridSize, coreNum); + if (gridSize >= static_cast(coreNum)) { + const int64_t gap = gridSize - static_cast(coreNum); + if (gap < bestOverGap) { +// bestOver = {dimTileSize, blocksPerSeq, gridSize}; + bestOver.dimTileSize = dimTileSize; + bestOver.blocksPerSeq = blocksPerSeq; + bestOver.gridSize = gridSize; + bestOverGap = gap; + } + } else if (gridSize > bestUnder.gridSize || + (gridSize == bestUnder.gridSize && dimTileSize < bestUnder.dimTileSize)) { bestUnder.dimTileSize = dimTileSize; bestUnder.blocksPerSeq = blocksPerSeq; bestUnder.gridSize = gridSize; + } } - } - DimTileChoice result = (bestOver.dimTileSize != 0) ? bestOver : bestUnder; + return (bestOver.dimTileSize != 0) ? 
bestOver : bestUnder; + }; + DimTileChoice result = ChooseOnce(true /*requireExactDiv*/); + if (result.dimTileSize == 0) { + result = ChooseOnce(false /*requireExactDiv*/); + } + OP_LOGD(context, + "DimTile chosen: dimTileSize[%ld], blocksPerSeq[%ld], gridSize[%ld].", + result.dimTileSize, result.blocksPerSeq, result.gridSize); return result; } static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum) { - auto compileInfoPtr = context->GetCompileInfo(); - if (compileInfoPtr != nullptr && compileInfoPtr->coreNum != 0 && compileInfoPtr->ubSize != 0) { - ubSize = compileInfoPtr->ubSize; - coreNum = compileInfoPtr->coreNum; - return ge::GRAPH_SUCCESS; - } fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo(); OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr); auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); coreNum = ascendcPlatform.GetCoreNumAiv(); - if(coreNum == 0) { - return ge::GRAPH_FAILED; - } + OP_CHECK_IF(coreNum == 0, OP_LOGE(context, "coreNum is 0"), return ge::GRAPH_FAILED); ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); - if(ubSize == 0) { - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; + OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"), return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; } static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context) @@ -116,7 +138,8 @@ static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context) return ge::GRAPH_SUCCESS; } -static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId) +static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId, + int64_t& runMode) { auto attrs = context->GetAttrs(); OP_CHECK_NULL_WITH_CONTEXT(context, attrs); @@ -124,17 +147,26 @@ static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activ const int64_t* activationModePtr = attrs->GetAttrPointer(ATTR_ACTIVATION_MODE_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, activationModePtr); activationMode = *activationModePtr; - if(activationMode != 0 && activationMode != 1){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF( + activationMode != 0 && activationMode != 1, OP_LOGE(context, "activationMode only supports 0/1"), + return ge::GRAPH_FAILED); + const int64_t* padSlotIdPtr = attrs->GetAttrPointer(ATTR_PAD_SLOT_ID_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, padSlotIdPtr); padSlotId = *padSlotIdPtr; + const int64_t* runModePtr = attrs->GetAttrPointer(ATTR_RUN_MODE_INDEX); + runMode = (runModePtr == nullptr) ? 
0 : *runModePtr; + OP_CHECK_IF(runMode != 0 && runMode != 1, OP_LOGE(context, "runMode only supports 0/1"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; } + static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling) { + const bool isDecodeMode = (tiling.runMode == 1); + auto xShapePtr = context->GetInputShape(X_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr); auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape()); @@ -145,150 +177,362 @@ static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalCon int64_t batch = 0; int64_t inputMode = 0; - if (xShape.GetDimNum() == 2) { - inputMode = 0; - cuSeqlen = xShape.GetDim(0); - dim = xShape.GetDim(1); - seqLen = 0; - if(dim <= 0 || cuSeqlen < 0){ - return ge::GRAPH_FAILED; + if (xShape.GetDimNum() == 2) { + if (isDecodeMode) { + inputMode = 2; + batch = xShape.GetDim(0); + dim = xShape.GetDim(1); + seqLen = 1; + cuSeqlen = batch; + OP_CHECK_IF(batch <= 0 || dim <= 0, + OP_LOGE(context, "invalid x shape for 2D decode mode"), + return ge::GRAPH_FAILED); + } else { + inputMode = 0; + cuSeqlen = xShape.GetDim(0); + dim = xShape.GetDim(1); + seqLen = 0; + OP_CHECK_IF(dim <= 0 || cuSeqlen < 0, + OP_LOGE(context, "invalid x shape for 2D varlen mode"), + return ge::GRAPH_FAILED); } - } else if (xShape.GetDimNum() == 3) { inputMode = 1; batch = xShape.GetDim(0); seqLen = xShape.GetDim(1); dim = xShape.GetDim(2); cuSeqlen = batch * seqLen; - if(batch <= 0 || dim <= 0 || seqLen <= 0){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF(batch <= 0 || dim <= 0 || seqLen <= 0, OP_LOGE(context, "invalid x shape for 3D batch mode"), return ge::GRAPH_FAILED); } else { + OP_LOGE(context, "x must be 2D (cu_seqlen, dim) or 3D (batch, seqlen, dim)"); return ge::GRAPH_FAILED; } auto wShapePtr = context->GetInputShape(WEIGHT_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr); auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape()); - if(wShape.GetDimNum() != 2){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF(wShape.GetDimNum() != 2, OP_LOGE(context, "weight must be 2D: (width, dim)"), return ge::GRAPH_FAILED); const int64_t width = wShape.GetDim(0); const int64_t wDim = wShape.GetDim(1); - if(wDim != dim){ - return ge::GRAPH_FAILED; - } - if(width != 4){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF(wDim != dim, OP_LOGE(context, "weight.shape[1] must equal dim"), return ge::GRAPH_FAILED); + OP_CHECK_IF(width < 2 || width > 4, + OP_LOGE(context, "Only support width in [2,4] now, actually is %ld.", width), + return ge::GRAPH_FAILED); + OP_CHECK_IF(dim % 16 != 0, + OP_LOGE(context, "dim must be a multiple of 16 for fp16/bf16 alignment, actually is %ld.", dim), + return ge::GRAPH_FAILED); auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr); auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape()); - if(sShape.GetDimNum() != 3){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF( + sShape.GetDimNum() != 3, OP_LOGE(context, "convStates must be 3D: (num_cache_lines, state_len, dim)"), + return ge::GRAPH_FAILED); const int64_t numCacheLines = sShape.GetDim(0); const int64_t stateLen = sShape.GetDim(1); const int64_t sDim = sShape.GetDim(2); - if(numCacheLines <= 0){ - return ge::GRAPH_FAILED;} - if(sDim != dim){ - return ge::GRAPH_FAILED;} - if(stateLen < (width - 1)){ - return ge::GRAPH_FAILED;} + OP_CHECK_IF(numCacheLines <= 0, OP_LOGE(context, "convStates.shape[0] (num_cache_lines) must be > 0"), return ge::GRAPH_FAILED); + OP_CHECK_IF(sDim 
!= dim, OP_LOGE(context, "convStates.shape[2] must equal dim"), return ge::GRAPH_FAILED); + OP_CHECK_IF(stateLen < (width - 1), OP_LOGE(context, "convStates.shape[1] must be >= width-1"), return ge::GRAPH_FAILED); - auto qslShapePtr = context->GetInputShape(QUERY_START_LOC_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context, qslShapePtr); - auto qslShape = EnsureNotScalar(qslShapePtr->GetStorageShape()); - if(qslShape.GetDimNum() != 1){ - return ge::GRAPH_FAILED;} - const int64_t qslSize = qslShape.GetDim(0); - if(qslSize < 1){ - return ge::GRAPH_FAILED;} + auto qslShapePtr = context->GetOptionalInputShape(QUERY_START_LOC_INDEX); + const gert::CompileTimeTensorDesc* qslDesc = context->GetOptionalInputDesc(QUERY_START_LOC_INDEX); + bool qslAbsent = true; + int64_t qslSize = 0; + if (qslShapePtr != nullptr) { + const auto qslStorageShape = qslShapePtr->GetStorageShape(); + const int64_t qslDimNum = qslStorageShape.GetDimNum(); + qslAbsent = (qslDimNum == 0) || (qslDimNum == 1 && qslStorageShape.GetDim(0) <= 0); + + if (!qslAbsent) { + auto qslShape = EnsureNotScalar(qslStorageShape); + OP_CHECK_IF(qslShape.GetDimNum() != 1, OP_LOGE(context, "queryStartLoc must be 1D"), + return ge::GRAPH_FAILED); + qslSize = qslShape.GetDim(0); + OP_CHECK_IF(qslSize < 1, OP_LOGE(context, "queryStartLoc.size must be >= 1"), + return ge::GRAPH_FAILED); + + OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc); + OP_CHECK_IF(qslDesc->GetDataType() != ge::DT_INT64, + OP_LOGE(context, "queryStartLoc dtype must be int64"), + return ge::GRAPH_FAILED); + } + } + + if (qslAbsent) { + OP_CHECK_IF(inputMode == 0, + OP_LOGE(context, "queryStartLoc is required in 2D varlen mode (inputMode=0)"), + return ge::GRAPH_FAILED); + qslSize = batch + 1; + } + + OP_CHECK_IF(cuSeqlen > static_cast(std::numeric_limits::max()), + OP_LOGE(context, "cuSeqlen is too large for int32 indexing, got %ld", cuSeqlen), + return ge::GRAPH_FAILED); + + const int64_t* qslData = nullptr; + if (!qslAbsent) { + const gert::Tensor* qslTensor = context->GetOptionalInputTensor(QUERY_START_LOC_INDEX); + qslData = (qslTensor != nullptr) ? 
qslTensor->GetData() : nullptr; + if (qslData != nullptr) { + OP_CHECK_IF(qslData[0] != 0, OP_LOGE(context, "queryStartLoc[0] must be 0"), + return ge::GRAPH_FAILED); + OP_CHECK_IF(qslData[qslSize - 1] != cuSeqlen, + OP_LOGE(context, "queryStartLoc[last] must equal cuSeqlen, got %ld vs %ld", + qslData[qslSize - 1], cuSeqlen), + return ge::GRAPH_FAILED); + for (int64_t i = 0; i + 1 < qslSize; ++i) { + const int64_t cur = qslData[i]; + const int64_t nxt = qslData[i + 1]; + OP_CHECK_IF(cur < 0 || cur > cuSeqlen, + OP_LOGE(context, "queryStartLoc[%ld] out of range: %ld (cuSeqlen=%ld)", i, cur, cuSeqlen), + return ge::GRAPH_FAILED); + OP_CHECK_IF(nxt < 0 || nxt > cuSeqlen, + OP_LOGE(context, "queryStartLoc[%ld] out of range: %ld (cuSeqlen=%ld)", i + 1, nxt, cuSeqlen), + return ge::GRAPH_FAILED); + OP_CHECK_IF(nxt < cur, + OP_LOGE(context, + "queryStartLoc must be non-decreasing, got queryStartLoc[%ld]=%ld queryStartLoc[%ld]=%ld", + i, cur, i + 1, nxt), + return ge::GRAPH_FAILED); + } + } + } + + if (!qslAbsent && isDecodeMode && inputMode == 2) { + const int64_t batchFromQsl = qslSize - 1; + if (batchFromQsl != batch) { + inputMode = 0; + cuSeqlen = xShape.GetDim(0); + batch = batchFromQsl; + seqLen = 0; + OP_CHECK_IF(dim <= 0 || cuSeqlen < 0 || batch < 0, + OP_LOGE(context, "invalid x/queryStartLoc shapes for 2D varlen decode mode"), + return ge::GRAPH_FAILED); + } + } if (inputMode == 0) { batch = qslSize - 1; } - if (inputMode == 1) { - if(qslSize != batch + 1){ - return ge::GRAPH_FAILED; + if (!qslAbsent && (inputMode == 1 || inputMode == 2)) { + OP_CHECK_IF(qslSize != batch + 1, OP_LOGE(context, "queryStartLoc.size must equal batch + 1"), + return ge::GRAPH_FAILED); + } + + if (isDecodeMode) { + const int64_t decodeSeqLen = (inputMode == 1) ? seqLen : 1; + OP_CHECK_IF(decodeSeqLen < 1, + OP_LOGE(context, "decode mode requires seqlen >= 1, actual is %ld", decodeSeqLen), + return ge::GRAPH_FAILED); + } + + tiling.hasCacheIndices = 0; + bool ciAbsent = true; + auto ciShapePtr = context->GetOptionalInputShape(CACHE_INDICES_INDEX); + if (ciShapePtr != nullptr) { + const auto ciStorageShape = ciShapePtr->GetStorageShape(); + const int64_t ciDimNum = ciStorageShape.GetDimNum(); + ciAbsent = (ciDimNum == 0) || (ciDimNum == 1 && ciStorageShape.GetDim(0) <= 0); + if (!ciAbsent) { + auto ciShape = EnsureNotScalar(ciStorageShape); + OP_CHECK_IF(ciShape.GetDimNum() != 1, OP_LOGE(context, "cacheIndices must be 1D"), return ge::GRAPH_FAILED); + OP_CHECK_IF(ciShape.GetDim(0) != batch, OP_LOGE(context, "cacheIndices.size must equal batch"), return ge::GRAPH_FAILED); + tiling.hasCacheIndices = 1; + + const gert::Tensor* ciTensor = context->GetOptionalInputTensor(CACHE_INDICES_INDEX); + const int64_t* ciData = (ciTensor != nullptr) ? 
ciTensor->GetData() : nullptr; + if (ciData != nullptr) { + for (int64_t i = 0; i < batch; ++i) { + const int64_t v = ciData[i]; + if (v == tiling.padSlotId) { + continue; + } + OP_CHECK_IF(!FitsInInt32(v), + OP_LOGE(context, "cacheIndices[%ld]=%ld does not fit int32", i, v), + return ge::GRAPH_FAILED); + OP_CHECK_IF(v < 0 || v >= numCacheLines, + OP_LOGE(context, + "cacheIndices[%ld]=%ld out of range [0, num_cache_lines=%ld), padSlotId=%ld", + i, v, numCacheLines, tiling.padSlotId), + return ge::GRAPH_FAILED); + } + } + } + } + if (ciAbsent) { + OP_CHECK_IF(numCacheLines < batch, + OP_LOGE(context, + "cacheIndices is absent, requires convStates.shape[0] (num_cache_lines) >= batch for identity mapping, got num_cache_lines=%ld batch=%ld", + numCacheLines, batch), + return ge::GRAPH_FAILED); + } + + tiling.hasInitialStateMode = 0; + auto ismShapePtr = context->GetOptionalInputShape(INITIAL_STATE_MODE_INDEX); + if (ismShapePtr != nullptr) { + const auto ismStorageShape = ismShapePtr->GetStorageShape(); + const int64_t ismDimNum = ismStorageShape.GetDimNum(); + const bool ismAbsent = (ismDimNum == 0) || (ismDimNum == 1 && ismStorageShape.GetDim(0) <= 0); + if (!ismAbsent) { + auto ismShape = EnsureNotScalar(ismStorageShape); + OP_CHECK_IF(ismShape.GetDimNum() != 1, OP_LOGE(context, "initialStateMode must be 1D"), + return ge::GRAPH_FAILED); + OP_CHECK_IF(ismShape.GetDim(0) != batch, OP_LOGE(context, "initialStateMode.size must equal batch"), + return ge::GRAPH_FAILED); + tiling.hasInitialStateMode = 1; + + const gert::Tensor* ismTensor = context->GetOptionalInputTensor(INITIAL_STATE_MODE_INDEX); + const int64_t* ismData = (ismTensor != nullptr) ? ismTensor->GetData() : nullptr; + if (ismData != nullptr) { + for (int64_t i = 0; i < batch; ++i) { + const int64_t v = ismData[i]; + OP_CHECK_IF(v != 0 && v != 1, + OP_LOGE(context, "initialStateMode[%ld]=%ld is invalid (only supports 0/1)", i, v), + return ge::GRAPH_FAILED); + } + } } } - auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr); - auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape()); - if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;} - if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;} - - auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr); - auto hisShape = EnsureNotScalar(hisShapePtr->GetStorageShape()); - if(hisShape.GetDimNum() != 1){ - return ge::GRAPH_FAILED;} - if(hisShape.GetDim(0) != batch){ - return ge::GRAPH_FAILED;} + tiling.hasNumAcceptedTokens = 0; + auto natShapePtr = context->GetOptionalInputShape(NUM_ACCEPTED_TOKENS_INDEX); + if (natShapePtr != nullptr) { + const auto natStorageShape = natShapePtr->GetStorageShape(); + const int64_t natDimNum = natStorageShape.GetDimNum(); + const bool natAbsent = (natDimNum == 0) || (natDimNum == 1 && natStorageShape.GetDim(0) <= 0); + if (!natAbsent) { + OP_CHECK_IF(!isDecodeMode, + OP_LOGE(context, "numAcceptedTokens is only supported in runMode=1 (decode/update)"), + return ge::GRAPH_FAILED); + auto natShape = EnsureNotScalar(natStorageShape); + OP_CHECK_IF(natShape.GetDimNum() != 1, OP_LOGE(context, "numAcceptedTokens must be 1D"), return ge::GRAPH_FAILED); + OP_CHECK_IF(natShape.GetDim(0) != batch, OP_LOGE(context, "numAcceptedTokens.size must equal batch"), return ge::GRAPH_FAILED); - tiling.set_hasBias(0); - auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX); - if (biasShapePtr != nullptr && 
biasShapePtr->GetStorageShape().GetDimNum() != 0) { - auto biasShape = EnsureNotScalar(biasShapePtr->GetStorageShape()); - if(biasShape.GetDimNum() != 1){ - return ge::GRAPH_FAILED;} - if(biasShape.GetDim(0) != dim){ - return ge::GRAPH_FAILED;} - tiling.set_hasBias(1); + if (inputMode == 1) { + const int64_t reqStateLen = (width - 1) + (seqLen - 1); + OP_CHECK_IF( + stateLen < reqStateLen, + OP_LOGE(context, + "spec decode requires stateLen >= (width-1) + (seqlen-1), got stateLen=%ld req=%ld", + stateLen, reqStateLen), + return ge::GRAPH_FAILED); + } + + const gert::Tensor* natTensor = context->GetOptionalInputTensor(NUM_ACCEPTED_TOKENS_INDEX); + const int64_t* natData = (natTensor != nullptr) ? natTensor->GetData() : nullptr; + if (natData != nullptr) { + for (int64_t i = 0; i < batch; ++i) { + const int64_t a = natData[i]; + OP_CHECK_IF(a < 0, + OP_LOGE(context, "numAcceptedTokens[%ld]=%ld is invalid (must be >= 0)", i, a), + return ge::GRAPH_FAILED); + OP_CHECK_IF(!FitsInInt32(a), + OP_LOGE(context, "numAcceptedTokens[%ld]=%ld does not fit int32", i, a), + return ge::GRAPH_FAILED); + + if (inputMode == 2) { + OP_CHECK_IF(a > 1, + OP_LOGE(context, + "numAcceptedTokens[%ld]=%ld exceeds decode 2D token count (1)", i, a), + return ge::GRAPH_FAILED); + } else if (inputMode == 1) { + OP_CHECK_IF(a > seqLen, + OP_LOGE(context, + "numAcceptedTokens[%ld]=%ld exceeds seqlen=%ld in 3D update", i, a, seqLen), + return ge::GRAPH_FAILED); + } else if (inputMode == 0) { + if (qslData != nullptr) { + const int64_t lenI = qslData[i + 1] - qslData[i]; + OP_CHECK_IF(a > lenI, + OP_LOGE(context, + "numAcceptedTokens[%ld]=%ld exceeds varlen segment length=%ld", + i, a, lenI), + return ge::GRAPH_FAILED); + } + } + } + } + + tiling.hasNumAcceptedTokens = 1; + } + } + + tiling.hasBias = 0; + auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX); + if (biasShapePtr != nullptr) { + const auto biasStorageShape = biasShapePtr->GetStorageShape(); + const int64_t biasDimNum = biasStorageShape.GetDimNum(); + const bool biasAbsent = (biasDimNum == 0) || (biasDimNum == 1 && biasStorageShape.GetDim(0) <= 0); + if (!biasAbsent) { + auto biasShape = EnsureNotScalar(biasStorageShape); + OP_CHECK_IF(biasShape.GetDimNum() != 1, OP_LOGE(context, "bias must be 1D: (dim,)"), return ge::GRAPH_FAILED); + OP_CHECK_IF(biasShape.GetDim(0) != dim, OP_LOGE(context, "bias.size must equal dim"), return ge::GRAPH_FAILED); + tiling.hasBias = 1; + } } const std::set supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16}; auto xDesc = context->GetInputDesc(X_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, xDesc); const ge::DataType xDtype = xDesc->GetDataType(); - if(supportedXDtype.count(xDtype) == 0){ - return ge::GRAPH_FAILED;} + OP_CHECK_IF(supportedXDtype.count(xDtype) == 0, OP_LOGE(context, "x dtype only supports bf16/fp16"), return ge::GRAPH_FAILED); auto wDesc = context->GetInputDesc(WEIGHT_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, wDesc); - if(wDesc->GetDataType() != xDtype){ - return ge::GRAPH_FAILED;} + OP_CHECK_IF(wDesc->GetDataType() != xDtype, OP_LOGE(context, "weight dtype must equal x dtype"), return ge::GRAPH_FAILED); - if (tiling.get_hasBias() == 1) { + if (tiling.hasBias == 1) { auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc); - if(biasDesc->GetDataType() != xDtype){ - return ge::GRAPH_FAILED;} + OP_CHECK_IF(biasDesc->GetDataType() != xDtype, OP_LOGE(context, "bias dtype must equal x dtype"), return ge::GRAPH_FAILED); } auto sDesc = 
context->GetInputDesc(CONV_STATES_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, sDesc); - if(sDesc->GetDataType() != xDtype){ - return ge::GRAPH_FAILED;} + OP_CHECK_IF(sDesc->GetDataType() != xDtype, OP_LOGE(context, "convStates dtype must equal x dtype"), return ge::GRAPH_FAILED); - auto qslDesc = context->GetInputDesc(QUERY_START_LOC_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc); - if(qslDesc->GetDataType() != ge::DT_INT32){ - return ge::GRAPH_FAILED;} + if (!qslAbsent) { + auto qslDesc2 = context->GetOptionalInputDesc(QUERY_START_LOC_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc2); + OP_CHECK_IF(qslDesc2->GetDataType() != ge::DT_INT64, OP_LOGE(context, "queryStartLoc dtype must be int64"), + return ge::GRAPH_FAILED); + } - auto ciDesc = context->GetInputDesc(CACHE_INDICES_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc); - if(ciDesc->GetDataType() != ge::DT_INT32){ - return ge::GRAPH_FAILED;} + if (tiling.hasCacheIndices == 1) { + auto ciDesc = context->GetOptionalInputDesc(CACHE_INDICES_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc); + OP_CHECK_IF(ciDesc->GetDataType() != ge::DT_INT64, OP_LOGE(context, "cacheIndices dtype must be int64"), + return ge::GRAPH_FAILED); + } - auto hisDesc = context->GetInputDesc(HAS_INITIAL_STATE_INDEX); - OP_CHECK_NULL_WITH_CONTEXT(context, hisDesc); - if(hisDesc->GetDataType() != ge::DT_BOOL){ - return ge::GRAPH_FAILED;} + if (tiling.hasInitialStateMode == 1) { + auto ismDesc = context->GetOptionalInputDesc(INITIAL_STATE_MODE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context, ismDesc); + OP_CHECK_IF(ismDesc->GetDataType() != ge::DT_INT64, + OP_LOGE(context, "initialStateMode dtype must be int64"), + return ge::GRAPH_FAILED); + } - tiling.set_dim(dim); - tiling.set_cuSeqlen(cuSeqlen); - tiling.set_seqLen(seqLen); - tiling.set_inputMode(inputMode); - tiling.set_width(width); - tiling.set_stateLen(stateLen); - tiling.set_numCacheLines(numCacheLines); - tiling.set_batch(batch); + if (tiling.hasNumAcceptedTokens == 1) { + OP_CHECK_IF(width != 4, + OP_LOGE(context, "numAcceptedTokens is only supported for width=4 currently"), + return ge::GRAPH_FAILED); + auto natDesc = context->GetOptionalInputDesc(NUM_ACCEPTED_TOKENS_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context, natDesc); + OP_CHECK_IF(natDesc->GetDataType() != ge::DT_INT64, OP_LOGE(context, "numAcceptedTokens dtype must be int64"), + return ge::GRAPH_FAILED); + } + + tiling.dim = dim; + tiling.cuSeqlen = cuSeqlen; + tiling.seqLen = seqLen; + tiling.inputMode = inputMode; + tiling.width = width; + tiling.stateLen = stateLen; + tiling.numCacheLines = numCacheLines; + tiling.batch = batch; return ge::GRAPH_SUCCESS; } @@ -296,70 +540,61 @@ static ge::graphStatus CausalConv1dTilingFunc(gert::TilingContext* context) { uint64_t ubSize; uint32_t coreNum; - if( GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF( + GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetPlatformInfo error"), + return ge::GRAPH_FAILED); - if(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS){ - return ge::GRAPH_FAILED; - } - CausalConv1dTilingData tilingData; + OP_CHECK_IF( + GetWorkspaceSize(context) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetWorkspaceSize error"), + return ge::GRAPH_FAILED); - int64_t activationMode = 0; - int64_t padSlotId = -1; - if(GetAttrsInfo(context, activationMode, padSlotId) != ge::GRAPH_SUCCESS){ - return ge::GRAPH_FAILED; - } - tilingData.set_activationMode(activationMode); - 
tilingData.set_padSlotId(padSlotId); + CausalConv1dTilingData* tiling = context->GetTilingData(); + OP_CHECK_NULL_WITH_CONTEXT(context, tiling); + OP_CHECK_IF( + memset_s(tiling, sizeof(CausalConv1dTilingData), 0, sizeof(CausalConv1dTilingData)) != EOK, + OP_LOGE(context, "set tiling data error"), return ge::GRAPH_FAILED); - if( GetShapeDtypeInfo(context, tilingData) != ge::GRAPH_SUCCESS){ - return ge::GRAPH_FAILED; - } + OP_CHECK_IF( + GetAttrsInfo(context, tiling->activationMode, tiling->padSlotId, tiling->runMode) != ge::GRAPH_SUCCESS, + OP_LOGE(context, "GetAttrsInfo error"), return ge::GRAPH_FAILED); + + OP_CHECK_IF( + GetShapeDtypeInfo(context, *tiling) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetShapeDtypeInfo error"), + return ge::GRAPH_FAILED); + + const int64_t dim = tiling->dim; + const int64_t batch = tiling->batch; + OP_CHECK_IF(dim <= 0 || batch <= 0, OP_LOGE(context, "dim/batch must be positive"), return ge::GRAPH_FAILED); - const int64_t dim = tilingData.get_dim(); - const int64_t batch = tilingData.get_batch(); - if(dim <= 0 || batch <= 0){ - return ge::GRAPH_FAILED; - } const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum); + OP_CHECK_IF(choice.dimTileSize <= 0 || choice.blocksPerSeq <= 0 || choice.gridSize <= 0, + OP_LOGE(context, "invalid dim_tile_size selection"), + return ge::GRAPH_FAILED); + const uint32_t blockDim = (choice.gridSize < static_cast(coreNum)) ? static_cast(choice.gridSize) : coreNum; + + OP_LOGD(context, + "Tiling result: batch[%ld], dim[%ld], dimTileSize[%ld], blocksPerSeq[%ld], gridSize[%ld], blockDim[%u], coreNum[%u].", + batch, dim, choice.dimTileSize, choice.blocksPerSeq, choice.gridSize, blockDim, coreNum); + context->SetBlockDim(blockDim); - tilingData.set_dimTileSize(choice.dimTileSize); - tilingData.set_blocksPerSeq(choice.blocksPerSeq); + tiling->dimTileSize = choice.dimTileSize; + tiling->blocksPerSeq = choice.blocksPerSeq; const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT); context->SetTilingKey(tilingKey); - - tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize()); return ge::GRAPH_SUCCESS; } - - static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context) { - auto platformInfoPtr = context->GetPlatformInfo(); - OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr); - auto compileInfoPtr = context->GetCompiledInfo(); - OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr); - - auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr); - compileInfoPtr->coreNum = static_cast(ascendcPlatform.GetCoreNumAiv()); - if(compileInfoPtr->coreNum == 0){ - return ge::GRAPH_FAILED; - } - ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize); - if(compileInfoPtr->ubSize == 0){ - return ge::GRAPH_FAILED; - } + OP_LOGD(context, "Enter TilingParseForCausalConv1d."); return ge::GRAPH_SUCCESS; } IMPL_OP_OPTILING(CausalConv1d) .Tiling(CausalConv1dTilingFunc) .TilingParse(TilingParseForCausalConv1d); -} // namespace optiling \ No newline at end of file +} // namespace optiling diff --git a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h deleted file mode 100644 index 28e74e5b..00000000 --- a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * This program is free software, you can redistribute it and/or modify it. 
- * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING - * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file causal_conv1d_tiling_data.h - * \brief - */ - -#ifndef ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H -#define ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H - -#include - -// #include "register/tilingdata_base.h" -// #include "tiling/tiling_api.h" -#include "register/tilingdata_base.h" -#include "error_log.h" -#include "register/op_impl_registry.h" -#include "tiling/platform/platform_ascendc.h" -#include "platform/platform_infos_def.h" -namespace optiling { - -BEGIN_TILING_DATA_DEF(CausalConv1dTilingData) - TILING_DATA_FIELD_DEF(int64_t, dim); - TILING_DATA_FIELD_DEF(int64_t, cuSeqlen); - TILING_DATA_FIELD_DEF(int64_t, seqLen); - TILING_DATA_FIELD_DEF(int64_t, inputMode); - - TILING_DATA_FIELD_DEF(int64_t, width); - - TILING_DATA_FIELD_DEF(int64_t, stateLen); - TILING_DATA_FIELD_DEF(int64_t, numCacheLines); - - TILING_DATA_FIELD_DEF(int64_t, batch); - - TILING_DATA_FIELD_DEF(int64_t, activationMode); - TILING_DATA_FIELD_DEF(int64_t, padSlotId); - - TILING_DATA_FIELD_DEF(int64_t, hasBias); - - TILING_DATA_FIELD_DEF(int64_t, dimTileSize); - TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq); -END_TILING_DATA_DEF; -struct CausalConv1dCompileInfo { - uint64_t ubSize = 0; - uint32_t coreNum = 0; -}; -REGISTER_TILING_DATA_CLASS(CausalConv1d, CausalConv1dTilingData) - -} // namespace optiling - -#endif // ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H \ No newline at end of file diff --git a/csrc/causal_conv1d/op_host/math_util.h b/csrc/causal_conv1d/op_host/math_util.h index edc1c8ea..b5f71430 100644 --- a/csrc/causal_conv1d/op_host/math_util.h +++ b/csrc/causal_conv1d/op_host/math_util.h @@ -58,4 +58,4 @@ public: static std::pair DivideIntoMainAndTail(int32_t num, int32_t divisor); }; } // namespace matmul_tiling -#endif // _MATH_UTIL_H_ +#endif // _MATH_UTIL_H_ \ No newline at end of file diff --git a/csrc/causal_conv1d/op_kernel/causal_conv1d.cpp b/csrc/causal_conv1d/op_kernel/causal_conv1d.cpp index de9308b6..798c3e51 100644 --- a/csrc/causal_conv1d/op_kernel/causal_conv1d.cpp +++ b/csrc/causal_conv1d/op_kernel/causal_conv1d.cpp @@ -18,13 +18,16 @@ namespace { - template +// NOTE: +// Dtype is provided via AscendC compile macros (e.g. DTYPE_X / ORIG_DTYPE_X), so tiling key does not need to carry dtype. 
+ +template __aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, - GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState, - GM_ADDR y, const NsCausalConv1d::CausalConv1dTilingData* tilingData) + GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR initialStateMode, + GM_ADDR numAcceptedTokens, GM_ADDR y, const CausalConv1dTilingData* tilingData) { NsCausalConv1d::CausalConv1d op; - op.Init(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, tilingData); + op.Init(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, tilingData); op.Process(); } @@ -32,27 +35,24 @@ __aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, template __global__ __aicore__ void causal_conv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, - GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState, - GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling) + GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR initialStateMode, + GM_ADDR numAcceptedTokens, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling) { - REGISTER_TILING_DEFAULT( NsCausalConv1d::CausalConv1dTilingData); - // GET_TILING_DATA_WITH_STRUCT( NsCausalConv1d::CausalConv1dTilingData, tilingData, tiling); - GET_TILING_DATA(tilingData, tiling); + REGISTER_TILING_DEFAULT(CausalConv1dTilingData); + GET_TILING_DATA_WITH_STRUCT(CausalConv1dTilingData, tilingData, tiling); + #if defined(ORIG_DTYPE_X) #if (ORIG_DTYPE_X == DT_FLOAT16) - RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData); + RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData); #elif (ORIG_DTYPE_X == DT_BF16) - RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData); - #elif (ORIG_DTYPE_X == DT_FLOAT) - RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData); + RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData); #endif #else #if (DTYPE_X == DT_FLOAT16) - RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData); + RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData); #elif (DTYPE_X == DT_BF16) - RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData); - #elif (DTYPE_X == DT_FLOAT) - RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData); + RunCausalConv1d(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData); #endif #endif -} \ No newline at end of file +} + diff --git a/csrc/causal_conv1d/op_kernel/causal_conv1d.h b/csrc/causal_conv1d/op_kernel/causal_conv1d.h index 3407dd37..33622cf1 100644 --- a/csrc/causal_conv1d/op_kernel/causal_conv1d.h +++ b/csrc/causal_conv1d/op_kernel/causal_conv1d.h @@ -12,90 +12,23 @@ /*! * \file causal_conv1d.h * \brief CausalConv1D (prefill/extend) AscendC kernel implementation. 
+ * */ #ifndef CAUSAL_CONV1D_H #define CAUSAL_CONV1D_H #include "kernel_operator.h" -// #include "kernel_tiling/kernel_tiling.h" +#include "kernel_tiling/kernel_tiling.h" +#include "causal_conv1d_tiling_data.h" #include "causal_conv1d_tiling_key.h" #include "causal_conv1d_common.h" -// #define ENABLE_CAUSAL_CONV1D_DEBUG - -// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG -// #define CCONV_PRINTF(fmt, ...) printf(fmt, ##__VA_ARGS__) -// #else -// #define CCONV_PRINTF(fmt, ...) -// #endif - -// #define CCONV_PRINT_IF(cond, fmt, ...) \ -// do { \ -// if (cond) { \ -// CCONV_PRINTF(fmt, ##__VA_ARGS__); \ -// } \ -// } while (0) - -// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG - -// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \ -// do { \ -// if (cond) { \ -// DumpTensor(tensor, __LINE__, size); \ -// } \ -// } while (0) -// #else -constexpr int32_t CCONV_DBG_SEQ = -1; -constexpr int32_t CCONV_DBG_C0 = -1; -constexpr int32_t CCONV_DBG_MAX_TOKENS = 0; -constexpr int32_t CCONV_DBG_VERBOSE_TOKENS = 0; -constexpr int32_t CCONV_DBG_DUMP_SIZE = 0; -constexpr bool CCONV_DBG_PRINT_SYNC = false; -constexpr bool CCONV_DBG_DUMP_WEIGHTS = false; -constexpr bool CCONV_DBG_DUMP_BIAS = false; -constexpr bool CCONV_DBG_DUMP_INIT_RING = false; -constexpr bool CCONV_DBG_DUMP_RUNSEQ = false; -constexpr bool CCONV_DBG_DUMP_PREFETCH = false; -constexpr bool CCONV_DBG_DUMP_STATE = false; - -// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \ -// do { \ -// } while (0) -// #endif -using namespace AscendC; namespace NsCausalConv1d { + +using namespace AscendC; using namespace NsCausalConv1dCommon; -#ifndef CAUSAL_CONV1D_TILING_DATA_H_ -#define CAUSAL_CONV1D_TILING_DATA_H_ - -struct CausalConv1dTilingData { - int64_t dim; - int64_t cuSeqlen; - int64_t seqLen; - int64_t inputMode; - - int64_t width; - - int64_t stateLen; - int64_t numCacheLines; - - int64_t batch; - - // attrs - int64_t activationMode; // 0: none, 1: silu/swish - int64_t padSlotId; // default -1 - - // optional inputs - int64_t hasBias; // 0/1 - - // Channel-wise tiling - int64_t dimTileSize; - int64_t blocksPerSeq; -}; -#endif // CAUSAL_CONV1D_TILING_DATA_H_ - template class CausalConv1d { @@ -103,18 +36,19 @@ public: __aicore__ inline CausalConv1d() = default; __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, GM_ADDR queryStartLoc, - GM_ADDR cacheIndices, GM_ADDR hasInitialState, GM_ADDR y - , - const CausalConv1dTilingData* tilingData); + GM_ADDR cacheIndices, GM_ADDR initialStateMode, GM_ADDR numAcceptedTokens, GM_ADDR y, + const CausalConv1dTilingData* tilingData); __aicore__ inline void Process(); private: - __aicore__ inline void LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg); - __aicore__ inline void InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len, - int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg); - __aicore__ inline void RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg); - __aicore__ inline void WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0, - int32_t dimTileSize, int32_t dim, bool dbg); + __aicore__ inline void LoadWeightAndBias(int32_t c0, int32_t dimTileSize); + __aicore__ inline void InitRing(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset, int32_t start, int32_t len, + int32_t c0, int32_t dimTileSize, int32_t dim); + __aicore__ inline void RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim); + __aicore__ inline void WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0, int32_t 
dimTileSize, int32_t dim); + __aicore__ inline void WriteBackStateSpec(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset, + int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, + int32_t dim); __aicore__ inline void AllocEvents(); __aicore__ inline void ReleaseEvents(); @@ -124,34 +58,43 @@ private: TBuf outBuf; TBuf calcBuf; - TEventID tempVToMte2Event_; - TEventID tempMte2ToVEvent_; - TEventID inputMte2ToVEvent_; + TEventID weightBiasMte2ToVEvent_; + TEventID stateMte2ToVEvent_; + TEventID inputMte2ToVEvent_[RING_SLOTS]; + TEventID inputVToMte2Event_; TEventID outMte3ToVEvent_[2]; TEventID outVToMte3Event_[2]; + TEventID stateWritebackMte3ToVEvent_; + TEventID stateWritebackMte3ToMte2Event_; + TEventID specWritebackMte2ToMte3Event_[2]; + TEventID specWritebackMte3ToMte2Event_[2]; GlobalTensor xGm; GlobalTensor weightGm; GlobalTensor biasGm; GlobalTensor convStatesGm; - GlobalTensor queryStartLocGm; - GlobalTensor cacheIndicesGm; - GlobalTensor hasInitialStateGm; + GlobalTensor queryStartLocGm; + GlobalTensor cacheIndicesGm; + GlobalTensor initialStateModeGm; + GlobalTensor numAcceptedTokensGm; GlobalTensor yGm; - const CausalConv1dTilingData* tilingData_ {nullptr}; + const CausalConv1dTilingData* tilingData_ {nullptr}; + + bool weightCacheValid_ {false}; + int32_t cachedC0_ {-1}; + int32_t cachedDimTileSize_ {-1}; }; template __aicore__ inline void CausalConv1d::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, - GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState, - GM_ADDR y - , const CausalConv1dTilingData* tilingData) + GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR initialStateMode, + GM_ADDR numAcceptedTokens, GM_ADDR y, const CausalConv1dTilingData* tilingData) { - // REGISTER_TILING_DEFAULT(CausalConv1dTilingData); - // auto tiling = (__gm__ CausalConv1dTilingData*)tilingGM; - // GET_TILING_DATA(tilingData, tilingGM); tilingData_ = tilingData; + weightCacheValid_ = false; + cachedC0_ = -1; + cachedDimTileSize_ = -1; xGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x)); weightGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(weight)); @@ -159,9 +102,18 @@ __aicore__ inline void CausalConv1d::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR biasGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(bias)); } convStatesGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(convStates)); - queryStartLocGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(queryStartLoc)); - cacheIndicesGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(cacheIndices)); - hasInitialStateGm.SetGlobalBuffer(reinterpret_cast<__gm__ bool*>(hasInitialState)); + if (tilingData_->inputMode == 0) { + queryStartLocGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(queryStartLoc)); + } + if (tilingData_->hasCacheIndices != 0) { + cacheIndicesGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(cacheIndices)); + } + if (tilingData_->hasInitialStateMode != 0) { + initialStateModeGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(initialStateMode)); + } + if (tilingData_->hasNumAcceptedTokens != 0) { + numAcceptedTokensGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(numAcceptedTokens)); + } yGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y)); pipe.InitBuffer(inBuf, RING_SLOTS * MAX_BLOCK_DIM * sizeof(T)); @@ -169,114 +121,143 @@ __aicore__ inline void CausalConv1d::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR pipe.InitBuffer(calcBuf, (MAX_WIDTH + 3) * MAX_BLOCK_DIM * sizeof(float)); AllocEvents(); - - // CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] dim=%d, dimTileSize=%d, 
blocksPerSeq=%d, batch=%d\n", - // tilingData_->dim, tilingData_->dimTileSize, tilingData_->blocksPerSeq, tilingData_->batch); - // CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] hasBias=%d, activationMode=%d, stateLen=%d, inputMode=%d\n", - // tilingData_->hasBias, tilingData_->activationMode, tilingData_->stateLen, tilingData_->inputMode); } template __aicore__ inline void CausalConv1d::AllocEvents() { - tempVToMte2Event_ = GetTPipePtr()->AllocEventID(); - tempMte2ToVEvent_ = GetTPipePtr()->AllocEventID(); - inputMte2ToVEvent_ = GetTPipePtr()->AllocEventID(); + weightBiasMte2ToVEvent_ = GetTPipePtr()->AllocEventID(); + stateMte2ToVEvent_ = GetTPipePtr()->AllocEventID(); + for (int32_t i = 0; i < RING_SLOTS; ++i) { + inputMte2ToVEvent_[i] = GetTPipePtr()->AllocEventID(); + } + inputVToMte2Event_ = GetTPipePtr()->AllocEventID(); outMte3ToVEvent_[0] = GetTPipePtr()->AllocEventID(); outMte3ToVEvent_[1] = GetTPipePtr()->AllocEventID(); outVToMte3Event_[0] = GetTPipePtr()->AllocEventID(); outVToMte3Event_[1] = GetTPipePtr()->AllocEventID(); + stateWritebackMte3ToVEvent_ = GetTPipePtr()->AllocEventID(); + stateWritebackMte3ToMte2Event_ = GetTPipePtr()->AllocEventID(); + specWritebackMte2ToMte3Event_[0] = GetTPipePtr()->AllocEventID(); + specWritebackMte2ToMte3Event_[1] = GetTPipePtr()->AllocEventID(); + specWritebackMte3ToMte2Event_[0] = GetTPipePtr()->AllocEventID(); + specWritebackMte3ToMte2Event_[1] = GetTPipePtr()->AllocEventID(); } template __aicore__ inline void CausalConv1d::ReleaseEvents() { - GetTPipePtr()->ReleaseEventID(tempVToMte2Event_); - GetTPipePtr()->ReleaseEventID(tempMte2ToVEvent_); - GetTPipePtr()->ReleaseEventID(inputMte2ToVEvent_); + GetTPipePtr()->ReleaseEventID(weightBiasMte2ToVEvent_); + GetTPipePtr()->ReleaseEventID(stateMte2ToVEvent_); + for (int32_t i = 0; i < RING_SLOTS; ++i) { + GetTPipePtr()->ReleaseEventID(inputMte2ToVEvent_[i]); + } + GetTPipePtr()->ReleaseEventID(inputVToMte2Event_); GetTPipePtr()->ReleaseEventID(outMte3ToVEvent_[0]); GetTPipePtr()->ReleaseEventID(outMte3ToVEvent_[1]); GetTPipePtr()->ReleaseEventID(outVToMte3Event_[0]); GetTPipePtr()->ReleaseEventID(outVToMte3Event_[1]); + GetTPipePtr()->ReleaseEventID(stateWritebackMte3ToVEvent_); + GetTPipePtr()->ReleaseEventID(stateWritebackMte3ToMte2Event_); + GetTPipePtr()->ReleaseEventID(specWritebackMte2ToMte3Event_[0]); + GetTPipePtr()->ReleaseEventID(specWritebackMte2ToMte3Event_[1]); + GetTPipePtr()->ReleaseEventID(specWritebackMte3ToMte2Event_[0]); + GetTPipePtr()->ReleaseEventID(specWritebackMte3ToMte2Event_[1]); } template -__aicore__ inline void CausalConv1d::LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg) +__aicore__ inline void CausalConv1d::LoadWeightAndBias(int32_t c0, int32_t dimTileSize) { const int32_t dim = tilingData_->dim; - const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC; - (void)dbgSync; + const int32_t width = static_cast(tilingData_->width); + const int32_t jStart = MAX_WIDTH - width; LocalTensor calc = calcBuf.Get(); LocalTensor weightF = calc; LocalTensor biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM]; - LocalTensor tempT = outBuf.Get(); + const bool hasBias = (tilingData_->hasBias != 0); - // CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] c0=%d, dimTileSize=%d\n", c0, dimTileSize); - - for (int32_t j = 0; j < MAX_WIDTH; ++j) { + for (int32_t j = 0; j < width; ++j) { + const int32_t jDst = jStart + j; const int64_t weightOffset = static_cast(j) * dim + c0; - PipeBarrier(); - DataCopy(tempT, weightGm[weightOffset], dimTileSize); - PipeBarrier(); - Cast(weightF[j * MAX_BLOCK_DIM], 
tempT, RoundMode::CAST_NONE, dimTileSize); - PipeBarrier(); - // if (dbg && CCONV_DBG_DUMP_WEIGHTS) { - // CCONV_PRINTF("[Dump][weightF] j=%d\n", j); - // CCONV_DUMP_TENSOR_IF(true, weightF[j * MAX_BLOCK_DIM], CCONV_DBG_DUMP_SIZE); - // } + + if constexpr (std::is_same::value) { + DataCopy(weightF[jDst * MAX_BLOCK_DIM], weightGm[weightOffset], dimTileSize); + } else { + DataCopy(weightF.ReinterpretCast()[jDst * MAX_BLOCK_DIM * 2 + MAX_BLOCK_DIM], weightGm[weightOffset], dimTileSize); + } } - if (tilingData_->hasBias != 0) { - PipeBarrier(); - DataCopy(tempT, biasGm[c0], dimTileSize); - PipeBarrier(); - Cast(biasF, tempT, RoundMode::CAST_NONE, dimTileSize); - PipeBarrier(); - // if (dbg && CCONV_DBG_DUMP_BIAS) { - // CCONV_PRINTF("[Dump][biasF]\n"); - // CCONV_DUMP_TENSOR_IF(true, biasF, CCONV_DBG_DUMP_SIZE); - // } - } else { - Duplicate(biasF, 0.0f, dimTileSize); - // CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] bias=0 (no bias)\n"); + if (hasBias) { + if constexpr (std::is_same::value) { + DataCopy(biasF, biasGm[c0], dimTileSize); + } else { + DataCopy(biasF.ReinterpretCast()[MAX_BLOCK_DIM], biasGm[c0], dimTileSize); + } + } + + SetFlag(weightBiasMte2ToVEvent_); + WaitFlag(weightBiasMte2ToVEvent_); + + if constexpr (!std::is_same::value) { + for (int32_t j = 0; j < width; ++j) { + const int32_t jDst = jStart + j; + Cast(weightF[jDst * MAX_BLOCK_DIM], weightF.ReinterpretCast()[jDst * MAX_BLOCK_DIM * 2 + MAX_BLOCK_DIM], + RoundMode::CAST_NONE, dimTileSize); + } + if (hasBias) { + Cast(biasF, biasF.ReinterpretCast()[MAX_BLOCK_DIM], RoundMode::CAST_NONE, dimTileSize); + } + } + + if (!hasBias) { + Duplicate(biasF, 0.0f, dimTileSize); } - PipeBarrier(); } template -__aicore__ inline void CausalConv1d::InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len, - int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg) +__aicore__ inline void CausalConv1d::InitRing(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset, + int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, + int32_t dim) { const int32_t stateLen = tilingData_->stateLen; + const int32_t width = static_cast(tilingData_->width); + const int32_t ringStart = MAX_WIDTH - width; LocalTensor ring = inBuf.Get(); - PipeBarrier(); if (hasInit) { - for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) { + for (int32_t i = 0; i < (width - 1); ++i) { + const int32_t pos = stateTokenOffset + i; const int64_t stateOffset = static_cast(cacheIdx) * stateLen * dim + - static_cast(i) * dim + c0; - DataCopy(ring[i * MAX_BLOCK_DIM], convStatesGm[stateOffset], dimTileSize); + static_cast(pos) * dim + c0; + DataCopy(ring[(ringStart + i) * MAX_BLOCK_DIM], convStatesGm[stateOffset], dimTileSize); } + SetFlag(stateMte2ToVEvent_); + WaitFlag(stateMte2ToVEvent_); } else { - for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) { - Duplicate(ring[i * MAX_BLOCK_DIM], static_cast(0), dimTileSize); + for (int32_t i = 0; i < (width - 1); ++i) { + Duplicate(ring[(ringStart + i) * MAX_BLOCK_DIM], static_cast(0), dimTileSize); } - + PipeBarrier(); } - PipeBarrier(); if (len > 0) { + const int32_t slot0 = SlotCurr(0); const int64_t xOffset = static_cast(start) * dim + c0; - PipeBarrier(); - DataCopy(ring[SlotCurr(0) * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize); - PipeBarrier(); + DataCopy(ring[slot0 * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize); + SetFlag(inputMte2ToVEvent_[slot0]); + } + + if (len > 1) { + SetFlag(inputVToMte2Event_); } } template __aicore__ inline void CausalConv1d::RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, - int32_t dim, 
bool dbg) + int32_t dim) { + const int32_t width = static_cast(tilingData_->width); + const int32_t jStart = MAX_WIDTH - width; LocalTensor calc = calcBuf.Get(); LocalTensor weightF = calc; LocalTensor biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM]; @@ -284,78 +265,77 @@ __aicore__ inline void CausalConv1d::RunSeq(int32_t start, int32_t len, int32 LocalTensor tmpF = accF[MAX_BLOCK_DIM]; LocalTensor ring = inBuf.Get(); LocalTensor outT = outBuf.Get(); - const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC; - (void)dbgSync; const bool hasActivation = (tilingData_->activationMode != 0); - const int32_t dbgMaxTokens = CCONV_DBG_MAX_TOKENS; - const int32_t dbgVerboseTokens = CCONV_DBG_VERBOSE_TOKENS; for (int32_t t = 0; t < len; ++t) { - const bool dbgTok = dbg && (t < dbgMaxTokens); - const bool dbgVerbose = dbg && CCONV_DBG_DUMP_RUNSEQ && (t < dbgVerboseTokens); - const bool dbgStep = dbgVerbose && (t == 0); const int32_t slotCurr = SlotCurr(t); - const int32_t slotH1 = SlotHist(t, 1); - const int32_t slotH2 = SlotHist(t, 2); - const int32_t slotH3 = SlotHist(t, 3); - const int32_t slotPref = (t + 1 < len) ? SlotPrefetch(t) : -1; - const int32_t outSlot = t & 1; + + WaitFlag(inputMte2ToVEvent_[slotCurr]); if (t + 1 < len) { - const int64_t xOffset = static_cast(start + t + 1) * dim + c0; - PipeBarrier(); - DataCopy(ring[slotPref * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize); - PipeBarrier(); - + const int32_t slotNext = SlotPrefetch(t); + const int64_t xOffsetNext = static_cast(start + t + 1) * dim + c0; + WaitFlag(inputVToMte2Event_); + DataCopy(ring[slotNext * MAX_BLOCK_DIM], xGm[xOffsetNext], dimTileSize); + SetFlag(inputMte2ToVEvent_[slotNext]); } DataCopy(accF, biasF, dimTileSize); + PipeBarrier(); - - for (int32_t j = 0; j < MAX_WIDTH; ++j) { + for (int32_t j = jStart; j < MAX_WIDTH; ++j) { const int32_t tap = (MAX_WIDTH - 1) - j; const int32_t slot = (tap == 0) ? 
 template <typename T>
 __aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
-                                                       int32_t dimTileSize, int32_t dim, bool dbg)
+                                                       int32_t dimTileSize, int32_t dim)
 {
     const int32_t stateLen = tilingData_->stateLen;
+    const int32_t width = static_cast<int32_t>(tilingData_->width);
     if (len <= 0) {
         return;
     }
@@ -363,14 +343,95 @@ __aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t
     const int32_t lastT = len - 1;
     LocalTensor<T> ring = inBuf.Get<T>();
 
-    for (int32_t pos = 0; pos < (MAX_WIDTH - 1); ++pos) {
-        const int32_t tap = (MAX_WIDTH - 2) - pos;
+    for (int32_t pos = 0; pos < (width - 1); ++pos) {
+        const int32_t tap = (width - 2) - pos;
         const int32_t slot = (tap == 0) ? SlotCurr(lastT) : SlotHist(lastT, tap);
         const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
                                     static_cast<int64_t>(pos) * dim + c0;
-        PipeBarrier<PIPE_V>();
         DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize);
-        PipeBarrier<PIPE_V>();
     }
 }
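WriteBackState persists the last (width - 1) tokens of the effective sequence, oldest first, so the next chunk can resume from the cache line. An equivalent NumPy model (illustrative names):

```python
import numpy as np

def write_back_state(hist, x_chunk, width):
    """hist: (width-1, d) entering state; x_chunk: (len, d) tokens just processed."""
    seq = np.concatenate([hist, x_chunk], axis=0)
    return seq[-(width - 1):]   # oldest first, ready for the next InitRing
```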
+template <typename T>
+__aicore__ inline void CausalConv1d<T>::WriteBackStateSpec(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset,
+                                                           int32_t start, int32_t len, int32_t c0,
+                                                           int32_t dimTileSize, int32_t dim)
+{
+    const int32_t width = static_cast<int32_t>(tilingData_->width);
+    const int32_t stateLen = tilingData_->stateLen;
+    if (len <= 0) {
+        return;
+    }
+
+    if (width != 4) {
+        WriteBackState(cacheIdx, len, c0, dimTileSize, dim);
+        return;
+    }
+
+    constexpr int32_t keep = MAX_WIDTH - 2;
+    const int32_t reqStateLen = keep + len;
+    if (reqStateLen > stateLen) {
+        WriteBackState(cacheIdx, len, c0, dimTileSize, dim);
+        return;
+    }
+
+    LocalTensor<T> ring = inBuf.Get<T>();
+    LocalTensor<T> buf0 = ring[0 * MAX_BLOCK_DIM];
+    LocalTensor<T> buf1 = ring[1 * MAX_BLOCK_DIM];
+
+    if (hasInit) {
+        const int32_t srcPos0 = stateTokenOffset + 1;
+        const int32_t srcPos1 = stateTokenOffset + 2;
+        const int64_t srcOffset0 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(srcPos0) * dim + c0;
+        const int64_t srcOffset1 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(srcPos1) * dim + c0;
+        DataCopy(buf0, convStatesGm[srcOffset0], dimTileSize);
+        DataCopy(buf1, convStatesGm[srcOffset1], dimTileSize);
+        PipeBarrier<PIPE_ALL>();
+        const int64_t dstOffset0 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(0) * dim + c0;
+        const int64_t dstOffset1 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(1) * dim + c0;
+        DataCopy(convStatesGm[dstOffset0], buf0, dimTileSize);
+        DataCopy(convStatesGm[dstOffset1], buf1, dimTileSize);
+        PipeBarrier<PIPE_ALL>();
+    } else {
+        Duplicate(buf0, static_cast<T>(0), dimTileSize);
+        PipeBarrier<PIPE_ALL>();
+        const int64_t dstOffset0 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(0) * dim + c0;
+        const int64_t dstOffset1 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(1) * dim + c0;
+        DataCopy(convStatesGm[dstOffset0], buf0, dimTileSize);
+        DataCopy(convStatesGm[dstOffset1], buf0, dimTileSize);
+        PipeBarrier<PIPE_ALL>();
+    }
+
+    const int64_t xOffset0 = static_cast<int64_t>(start) * dim + c0;
+    DataCopy(buf0, xGm[xOffset0], dimTileSize);
+    SetFlag<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[0]);
+
+    for (int32_t t = 0; t < len; ++t) {
+        const int32_t curr = t & 1;
+        const int32_t next = curr ^ 1;
+        LocalTensor<T> currBuf = (curr == 0) ? buf0 : buf1;
+        LocalTensor<T> nextBuf = (next == 0) ? buf0 : buf1;
+
+        WaitFlag<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[curr]);
+
+        if (t + 1 < len) {
+            const int64_t xOffsetNext = static_cast<int64_t>(start + t + 1) * dim + c0;
+            if (t > 0) {
+                WaitFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[next]);
+            }
+            DataCopy(nextBuf, xGm[xOffsetNext], dimTileSize);
+            SetFlag<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[next]);
+        }
+
+        const int64_t dstOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
+                                  static_cast<int64_t>(keep + t) * dim + c0;
+        DataCopy(convStatesGm[dstOffset], currBuf, dimTileSize);
+        SetFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[curr]);
+    }
+
+    WaitFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[0]);
+    if (len > 1) {
+        WaitFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[1]);
+    }
+}
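The spec-decoding writeback rewrites the cache line in place: the two state tokens after the last accepted one are moved to the front (or zeroed when there is no initial state), and the whole chunk is then appended behind them with double-buffered GM copies. A NumPy model under those assumptions (width == 4, keep == 2):

```python
import numpy as np

def write_back_state_spec(conv_state, x_chunk, state_token_offset, has_init):
    """conv_state: (stateLen, d); x_chunk: (len, d); mirrors the width == 4 fast path."""
    keep = 2                                            # MAX_WIDTH - 2
    if has_init:                                        # shift the two survivors forward
        conv_state[:keep] = conv_state[state_token_offset + 1:state_token_offset + 1 + keep]
    else:
        conv_state[:keep] = 0.0
    conv_state[keep:keep + len(x_chunk)] = x_chunk      # what the pipelined copy loop does
    return conv_state
```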
 
@@ -383,11 +444,14 @@ __aicore__ inline void CausalConv1d<T>::Process()
     const int32_t seqLen = tilingData_->seqLen;
     const int32_t dimTileSize = static_cast<int32_t>(tilingData_->dimTileSize);
     const int32_t blocksPerSeq = static_cast<int32_t>(tilingData_->blocksPerSeq);
+    const int32_t width = static_cast<int32_t>(tilingData_->width);
+    const bool isSpecDecodingGlobal =
+        (tilingData_->runMode == 1) && (tilingData_->hasNumAcceptedTokens != 0) && (width == 4);
     const uint32_t blockIdx = GetBlockIdx();
     const uint32_t blockNum = GetBlockNum();
 
-    if (dimTileSize <= 0 || blocksPerSeq <= 0 || dimTileSize > MAX_BLOCK_DIM || blocksPerSeq * dimTileSize != dim) {
+    if (dimTileSize <= 0 || blocksPerSeq <= 0 || dimTileSize > MAX_BLOCK_DIM || width < 2 || width > MAX_WIDTH) {
         ReleaseEvents();
         return;
     }
@@ -397,9 +461,10 @@ __aicore__ inline void CausalConv1d<T>::Process()
         const int32_t seq = static_cast<int32_t>(task / blocksPerSeq);
         const int32_t dimBlockId = static_cast<int32_t>(task % blocksPerSeq);
         const int32_t c0 = dimBlockId * dimTileSize;
-        const bool dbg = (seq == CCONV_DBG_SEQ) && (c0 == CCONV_DBG_C0);
-
-        LoadWeightAndBias(c0, dimTileSize, dbg);
+        if (c0 >= dim) {
+            continue;
+        }
+        const int32_t dimTileSizeActual = (c0 + dimTileSize <= dim) ? dimTileSize : (dim - c0);
 
         int32_t start = 0;
         int32_t len = 0;
@@ -408,6 +473,9 @@ __aicore__ inline void CausalConv1d<T>::Process()
             const int32_t endVal = queryStartLocGm.GetValue(seq + 1);
             start = startVal;
             len = endVal - startVal;
+        } else if (inputMode == 2) {
+            start = seq;
+            len = 1;
         } else {
             start = seq * seqLen;
             len = seqLen;
         }
@@ -417,20 +485,59 @@ __aicore__ inline void CausalConv1d<T>::Process()
             continue;
         }
 
-        const int32_t cacheIdx = cacheIndicesGm.GetValue(seq);
-        if (cacheIdx == tilingData_->padSlotId) {
-            continue;
+        int32_t cacheIdx = seq;
+        if (tilingData_->hasCacheIndices != 0) {
+            const int64_t cacheIdx64 = cacheIndicesGm.GetValue(seq);
+            if (cacheIdx64 == tilingData_->padSlotId) {
+                continue;
+            }
+            cacheIdx = static_cast<int32_t>(cacheIdx64);
         }
 
-        const bool hasInit = hasInitialStateGm.GetValue(seq);
+        const bool hasInit =
+            (tilingData_->hasInitialStateMode != 0) ? (initialStateModeGm.GetValue(seq) != 0) : false;
+        int32_t stateTokenOffset = 0;
+        if (isSpecDecodingGlobal) {
+            int32_t accepted = static_cast<int32_t>(numAcceptedTokensGm.GetValue(seq));
+            stateTokenOffset = accepted - 1;
+            const int32_t maxOffset = static_cast<int32_t>(tilingData_->stateLen - (width - 1));
+            if (stateTokenOffset < 0) {
+                stateTokenOffset = 0;
+            } else if (stateTokenOffset > maxOffset) {
+                stateTokenOffset = maxOffset;
+            }
+        }
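The clamp above keeps the (width - 1) speculative history reads inside the cache line. In plain Python:

```python
def state_token_offset(num_accepted, state_len, width):
    """Offset of the last accepted token's history window, clamped to the cache line."""
    off = num_accepted - 1
    return max(0, min(off, state_len - (width - 1)))
```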
 
-        InitRing(cacheIdx, hasInit, start, len, c0, dimTileSize, dim, dbg);
-        RunSeq(start, len, c0, dimTileSize, dim, dbg);
-        WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg);
+        const bool weightCacheHit =
+            weightCacheValid_ && (cachedC0_ == c0) && (cachedDimTileSize_ == dimTileSizeActual);
+        if (!weightCacheHit) {
+            LoadWeightAndBias(c0, dimTileSizeActual);
+            weightCacheValid_ = true;
+            cachedC0_ = c0;
+            cachedDimTileSize_ = dimTileSizeActual;
+        }
+
+        InitRing(cacheIdx, hasInit, stateTokenOffset, start, len, c0, dimTileSizeActual, dim);
+        RunSeq(start, len, c0, dimTileSizeActual, dim);
+
+        SetFlag<HardEvent::MTE3_V>(stateWritebackMte3ToVEvent_);
+        WaitFlag<HardEvent::MTE3_V>(stateWritebackMte3ToVEvent_);
+        SetFlag<HardEvent::MTE3_MTE2>(stateWritebackMte3ToMte2Event_);
+        WaitFlag<HardEvent::MTE3_MTE2>(stateWritebackMte3ToMte2Event_);
+
+        if (isSpecDecodingGlobal) {
+            WriteBackStateSpec(cacheIdx, hasInit, stateTokenOffset, start, len, c0, dimTileSizeActual, dim);
+        } else {
+            WriteBackState(cacheIdx, len, c0, dimTileSizeActual, dim);
+        }
+
+        PipeBarrier<PIPE_MTE2>();
+        PipeBarrier<PIPE_MTE3>();
+        PipeBarrier<PIPE_V>();
     }
 
     ReleaseEvents();
 }
 
 } // namespace NsCausalConv1d
 
-#endif // CAUSAL_CONV1D_H
\ No newline at end of file
+#endif // CAUSAL_CONV1D_H
diff --git a/csrc/causal_conv1d/op_kernel/causal_conv1d_common.h b/csrc/causal_conv1d/op_kernel/causal_conv1d_common.h
index 39861092..f946d4e7 100644
--- a/csrc/causal_conv1d/op_kernel/causal_conv1d_common.h
+++ b/csrc/causal_conv1d/op_kernel/causal_conv1d_common.h
@@ -42,4 +42,4 @@ __aicore__ inline int32_t SlotPrefetch(int32_t t)
 
 } // namespace NsCausalConv1dCommon
 
-#endif // CAUSAL_CONV1D_COMMON_H
\ No newline at end of file
+#endif // CAUSAL_CONV1D_COMMON_H
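For readers following Process(): work is partitioned into (sequence, dim-tile) tasks walked with a stride of the core count, and the last tile of a row may now be narrower since dim no longer has to be a multiple of dimTileSize. A Python sketch of that scheduling:

```python
def tasks_for_core(block_idx, block_num, num_seqs, blocks_per_seq, dim, dim_tile_size):
    """Yield the (seq, c0, tile) triples one core handles, tail tiles included."""
    for task in range(block_idx, num_seqs * blocks_per_seq, block_num):
        seq, dim_block = divmod(task, blocks_per_seq)
        c0 = dim_block * dim_tile_size
        if c0 >= dim:
            continue                          # padded tail task, nothing to do
        tile = min(dim_tile_size, dim - c0)   # dimTileSizeActual
        yield seq, c0, tile
```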
diff --git a/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_data.h b/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_data.h
new file mode 100644
index 00000000..3f1fe365
--- /dev/null
+++ b/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_data.h
@@ -0,0 +1,49 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file causal_conv1d_tiling_data.h
+ * \brief tiling data struct
+ */
+
+#ifndef CAUSAL_CONV1D_TILING_DATA_H_
+#define CAUSAL_CONV1D_TILING_DATA_H_
+
+#include <cstdint>
+
+struct CausalConv1dTilingData {
+    int64_t dim;
+    int64_t cuSeqlen;
+    int64_t seqLen;
+    int64_t inputMode;
+    int64_t runMode;
+
+    int64_t width;
+
+    int64_t stateLen;
+    int64_t numCacheLines;
+
+    int64_t batch;
+
+    int64_t activationMode;
+    int64_t padSlotId;
+
+    int64_t hasBias;
+
+    int64_t dimTileSize;
+    int64_t blocksPerSeq;
+
+    int64_t hasNumAcceptedTokens;
+
+    int64_t hasCacheIndices;
+    int64_t hasInitialStateMode;
+};
+#endif // CAUSAL_CONV1D_TILING_DATA_H_
diff --git a/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_key.h b/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_key.h
index a456b625..a2d2ec54 100644
--- a/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_key.h
+++ b/csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_key.h
@@ -31,4 +31,4 @@ ASCENDC_TPL_SEL(
         ASCENDC_TPL_UINT_SEL(
             schMode, ASCENDC_TPL_UI_LIST, CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT)));
 
-#endif // __CAUSAL_CONV1D_TILING_KEY_H__
\ No newline at end of file
+#endif // __CAUSAL_CONV1D_TILING_KEY_H__
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index b79065ae..386fb9af 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -633,40 +633,34 @@ npu_copy_and_expand_eagle_inputs(
             out_new_token_indices, out_hidden_state_mapping};
 }
 
-at::Tensor causal_conv1d_fn(
-    const at::Tensor& mixed_qkv_non_spec_T,
-    const at::Tensor& conv_weights,
-    const c10::optional<at::Tensor>& bias_opt,
-    c10::string_view activation,
+at::Tensor npu_causal_conv1d_custom(
+    const at::Tensor& x,
+    const at::Tensor& weight,
     const at::Tensor& conv_state,
-    const at::Tensor& has_initial_state,
-    const at::Tensor& non_spec_state_indices_tensor,
-    const at::Tensor& non_spec_query_start_loc,
-    int64_t pad_slot_id)
+    const c10::optional<at::Tensor>& bias_opt,
+    at::IntArrayRef query_start_loc_opt,
+    at::IntArrayRef cache_indices_opt,
+    at::IntArrayRef initial_state_mode_opt,
+    at::IntArrayRef num_accepted_tokens_opt,
+    int64_t activation_mode,
+    int64_t pad_slot_id,
+    int64_t run_mode)
 {
-    at::Tensor x=mixed_qkv_non_spec_T; // no transpose needed
-    at::Tensor weight=conv_weights; // no transpose needed
-    c10::optional<at::Tensor> biasOptional =bias_opt;
-    at::Tensor convStates= conv_state;
-    at::Tensor queryStartLoc=non_spec_query_start_loc;
-    at::Tensor cacheIndices=non_spec_state_indices_tensor;
-    at::Tensor hasInitialState=has_initial_state;
-    int64_t activationMode=(activation.empty()?0:1);
-    int64_t padSlotId=pad_slot_id;
-
-    at::Tensor output = at::empty(mixed_qkv_non_spec_T.sizes(), mixed_qkv_non_spec_T.options());
+    at::Tensor output = at::empty(x.sizes(), x.options());
 
     EXEC_NPU_CMD(aclnnCausalConv1d,
-        x,
+        x,
         weight,
-        biasOptional,
-        convStates,
-        queryStartLoc,
-        cacheIndices,
-        hasInitialState,
-        activationMode,
-        padSlotId,
+        bias_opt,
+        conv_state,
+        query_start_loc_opt,
+        cache_indices_opt,
+        initial_state_mode_opt,
+        num_accepted_tokens_opt,
+        activation_mode,
+        pad_slot_id,
+        run_mode,
         output
-        );
+    );
     return output;
 }
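Call-shape sketch for the new schema: the optional index inputs are host-side int64 lists (`int[]` in the schema, ValueDepend(OPTIONAL) on the op side), so an empty list means "absent". All shapes and values below are made up for illustration; the real usage is in the test and model patches further down.

```python
import torch

# hypothetical sizes: 8 tokens over 2 sequences, dim 64, width 4, 2 cache lines
x = torch.randn(8, 64, dtype=torch.bfloat16).npu()
w = torch.randn(4, 64, dtype=torch.bfloat16).npu()
state = torch.zeros(2, 32, 64, dtype=torch.bfloat16).npu()
y = torch.ops._C_ascend.npu_causal_conv1d_custom(
    x, w,
    conv_state=state,
    bias_opt=None,
    query_start_loc_opt=[0, 5, 8],   # prefix sums of the two sequence lengths
    cache_indices_opt=[0, 1],
    initial_state_mode_opt=[0, 0],   # no initial state for either sequence
    num_accepted_tokens_opt=[],      # only consulted when run_mode == 1
    activation_mode=1,               # 1 -> SiLU, 0 -> none
    pad_slot_id=-1,
    run_mode=0,                      # 0 -> prefill path
)
```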
@@ -895,18 +889,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         "Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
     );
     ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
-    // causal_conv1d_fn
     ops.def(
-        "causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
-        "                 Tensor conv_weights, "
-        "                 Tensor? bias_opt, "
-        "                 str activation, "
+        "npu_causal_conv1d_custom(Tensor x, "
+        "                 Tensor weight, "
         "                 Tensor conv_state, "
-        "                 Tensor has_initial_state, "
-        "                 Tensor non_spec_state_indices_tensor, "
-        "                 Tensor non_spec_query_start_loc, "
-        "                 int pad_slot_id) -> (Tensor output)");
-    ops.impl("causal_conv1d_fn", torch::kPrivateUse1, &vllm_ascend::causal_conv1d_fn);
+        "                 Tensor? bias_opt, "
+        "                 int[] query_start_loc_opt, "
+        "                 int[] cache_indices_opt, "
+        "                 int[] initial_state_mode_opt, "
+        "                 int[] num_accepted_tokens_opt, "
+        "                 int activation_mode, "
+        "                 int pad_slot_id, "
+        "                 int run_mode"
+        ") -> (Tensor output)");
+    ops.impl("npu_causal_conv1d_custom", torch::kPrivateUse1, &vllm_ascend::npu_causal_conv1d_custom);
     ops.def(
         "moe_grouped_matmul("
         "Tensor x,"
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index 1f62ce2c..ed7f7dc8 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -485,19 +485,21 @@ npu_copy_and_expand_eagle_inputs_meta(
             out_new_token_indices, out_hidden_state_mapping};
 }
 
-at::Tensor causal_conv1d_fn_meta(
-    const at::Tensor& mixed_qkv_non_spec_T,
-    const at::Tensor& conv_weights,
-    const c10::optional<at::Tensor>& bias_opt,
-    c10::string_view activation,
+at::Tensor npu_causal_conv1d_custom_meta(
+    const at::Tensor& x,
+    const at::Tensor& weight,
     const at::Tensor& conv_state,
-    const at::Tensor& has_initial_state,
-    const at::Tensor& non_spec_state_indices_tensor,
-    const at::Tensor& non_spec_query_start_loc,
-    int64_t pad_slot_id)
+    const c10::optional<at::Tensor>& bias_opt,
+    at::IntArrayRef query_start_loc_opt,
+    at::IntArrayRef cache_indices_opt,
+    at::IntArrayRef initial_state_mode_opt,
+    at::IntArrayRef num_accepted_tokens_opt,
+    int64_t activation_mode,
+    int64_t pad_slot_id,
+    int64_t run_mode)
 {
-    at::Tensor output = at::empty_symint(mixed_qkv_non_spec_T.sym_sizes(), mixed_qkv_non_spec_T.options());
+    at::Tensor output = at::empty_symint(x.sym_sizes(), x.options());
     return output;
 }
 
@@ -611,7 +613,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
     // CopyAndExpandEagleInputs
     ops.impl("npu_copy_and_expand_eagle_inputs", &vllm_ascend::meta::npu_copy_and_expand_eagle_inputs_meta);
     // causal_conv1d_fn
-    ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
+    ops.impl("npu_causal_conv1d_custom", &vllm_ascend::meta::npu_causal_conv1d_custom_meta);
     // moe_grouped_matmul
     ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
     // Lightning indexer quant
diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py
index 89b437d8..88fc9591 100644
--- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py
@@ -157,6 +157,11 @@ def causal_conv1d_fn_pytorch(
     out_ref_tensor = torch.cat(out_ref, dim=0)
     return out_ref_tensor
 
+def to_int64_tuple(t):
+    t = t.to(torch.int64)
+    if t.dim() == 0:
+        return (t.item(),)
+    return tuple(t.tolist())
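to_int64_tuple in action (assuming the helper just defined is in scope): scalars become 1-tuples and vectors become tuples, matching the schema's `int[]` arguments.

```python
import torch

assert to_int64_tuple(torch.tensor([0, 5, 8], dtype=torch.int32)) == (0, 5, 8)
assert to_int64_tuple(torch.tensor(3)) == (3,)                # 0-d -> 1-tuple
assert to_int64_tuple(torch.tensor([True, False])) == (1, 0)  # bool -> int64
```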
 
 @pytest.mark.parametrize('has_initial_state', [False, True])
 @pytest.mark.parametrize('itype', [torch.bfloat16])
@@ -227,16 +232,19 @@ def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
     x_origin=x.transpose(-1, -2)
     weight_origin=weight.transpose(-1, -2)
     conv_states_origin=conv_states.transpose(-1, -2)
-    out = torch.ops._C_ascend.causal_conv1d_fn(
+    activation_num = 1 if activation else 0
+    out = torch.ops._C_ascend.npu_causal_conv1d_custom(
         x_origin,
         weight_origin,
-        bias,
-        activation=activation,
         conv_state=conv_states_origin,
-        has_initial_state=has_initial_state_tensor,
-        non_spec_state_indices_tensor=cache_indices,
-        non_spec_query_start_loc=query_start_loc,
+        bias_opt=bias,
+        query_start_loc_opt=to_int64_tuple(query_start_loc),
+        cache_indices_opt=to_int64_tuple(cache_indices),
+        initial_state_mode_opt=to_int64_tuple(has_initial_state_tensor),
+        num_accepted_tokens_opt=[],
+        activation_mode=activation_num,
         pad_slot_id=PAD_SLOT_ID,
+        run_mode=0
     ).transpose(-1, -2)
     validate_cmp(out, out_ref, itype)
     validate_cmp(conv_states, conv_states_ref, itype)
diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py
index 4bbe5ead..d19d4ade 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_5.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_5.py
@@ -33,6 +33,13 @@ from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
 from vllm_ascend.utils import enable_sp, vllm_version_is
 
 
+def to_int64_tuple(t):
+    t = t.to(torch.int64)
+    if t.dim() == 0:
+        return (t.item(),)
+    return tuple(t.tolist())
+
+
 class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
     def _forward_core(
         self,
@@ -110,16 +117,19 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
         if attn_metadata.num_prefills > 0:
             if mixed_qkv_non_spec is not None:
                 conv_weights_T = conv_weights.transpose(0, 1)
-                mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
+                activation_num = 1 if self.activation else 0
+                mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom(
                     mixed_qkv_non_spec,
                     conv_weights_T,
-                    self.conv1d.bias,
-                    activation=self.activation,
                     conv_state=self_kv_cache[0],
-                    has_initial_state=has_initial_state,
-                    non_spec_state_indices_tensor=non_spec_state_indices_tensor,
-                    non_spec_query_start_loc=non_spec_query_start_loc,
+                    bias_opt=self.conv1d.bias,
+                    query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc),
+                    cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor),
+                    initial_state_mode_opt=to_int64_tuple(has_initial_state),
+                    num_accepted_tokens_opt=[],
+                    activation_mode=activation_num,
                     pad_slot_id=PAD_SLOT_ID,
+                    run_mode=0,
                 )
         elif attn_metadata.num_decodes > 0:
             mixed_qkv_non_spec = causal_conv1d_update(
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
index bbbeb8bb..ff7e0c22 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_next.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -32,6 +32,7 @@ from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
+from vllm_ascend.patch.worker.patch_qwen3_5 import to_int64_tuple
 from vllm_ascend.utils import enable_sp, vllm_version_is
 
 
@@ -167,16 +168,19 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
         if attn_metadata.num_prefills > 0:
             if mixed_qkv_non_spec is not None:
                 conv_weights_T = conv_weights.transpose(0, 1)
-                mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
+                activation_num = 1 if self.activation else 0
+                mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom(
                     mixed_qkv_non_spec,
                     conv_weights_T,
-                    self.conv1d.bias,
-                    activation=self.activation,
                     conv_state=self_kv_cache[0],
-                    has_initial_state=has_initial_state,
-                    non_spec_state_indices_tensor=non_spec_state_indices_tensor,
-                    non_spec_query_start_loc=non_spec_query_start_loc,
+                    bias_opt=self.conv1d.bias,
+                    query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc),
+                    cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor),
+                    initial_state_mode_opt=to_int64_tuple(has_initial_state),
+                    num_accepted_tokens_opt=[],
+                    activation_mode=activation_num,
                     pad_slot_id=PAD_SLOT_ID,
+                    run_mode=0,
                 )
         elif attn_metadata.num_decodes > 0:
             mixed_qkv_non_spec = causal_conv1d_update(