[Ops][Misc] Refactor and optimize CausalConv1d for Ascend (#7495)
### What this PR does / why we need it?
During the prefill phase of Qwen3-Next and Qwen3.5, the
`torch.ops._C_ascend.causal_conv1d_fn` operator exhibits significant
performance bottlenecks. To address this, we have re-implemented the
optimization using `torch.ops._C_ascend.npu_causal_conv1d_custom`.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
1 accuracy test
```
[2026-03-20 16:44:22,961] [ais_bench] [INFO] Start launch task state board ...
+-----------------------------+-----------+------------+-------------+----------+-------------------------------------------+---------------------+
| Task Name | Process | Progress | Time Cost | Status | Log Path | Extend Parameters |
+=============================+===========+============+=============+==========+===========================================+=====================+
| vllm-api-general-chat/gsm8k | 2918978 | NA | 0:00:01 | finish | logs/eval/vllm-api-general-chat/gsm8k.out | None |
+-----------------------------+-----------+------------+-------------+----------+-------------------------------------------+---------------------+
[2026-03-20 16:44:34,284] [ais_bench] [INFO] Evaluation tasks completed.
[2026-03-20 16:44:34,287] [ais_bench] [INFO] Summarizing evaluation results...
dataset version metric mode vllm-api-general-chat
--------- --------- -------- ------ -----------------------
gsm8k 271d0b accuracy gen 96.21
```
2 ut modify test
`pytest -sv
/home/c30006096/vllm-ascend/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py::test_ascend_causal_conv1d`
- vLLM version: v0.17.0
- vLLM main:
8b6325758c
Signed-off-by: wenba0 <3054239545@qq.com>
Signed-off-by: jiaojiao <56385650+wenba0@users.noreply.github.com>
This commit is contained in:
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
add_ops_compile_options(
|
add_ops_compile_options(
|
||||||
OP_NAME CausalConv1d
|
OP_NAME CausalConv1d
|
||||||
OPTIONS --cce-auto-sync=off
|
OPTIONS --cce-auto-sync=on
|
||||||
-Wno-deprecated-declarations
|
-Wno-deprecated-declarations
|
||||||
-Werror
|
-Werror
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,19 +42,28 @@ public:
|
|||||||
.FormatList({ge::FORMAT_ND})
|
.FormatList({ge::FORMAT_ND})
|
||||||
.AutoContiguous();
|
.AutoContiguous();
|
||||||
this->Input("queryStartLoc")
|
this->Input("queryStartLoc")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(OPTIONAL)
|
||||||
.DataTypeList({ge::DT_INT32})
|
.DataTypeList({ge::DT_INT64})
|
||||||
.FormatList({ge::FORMAT_ND})
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.ValueDepend(OPTIONAL)
|
||||||
.AutoContiguous();
|
.AutoContiguous();
|
||||||
this->Input("cacheIndices")
|
this->Input("cacheIndices")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(OPTIONAL)
|
||||||
.DataTypeList({ge::DT_INT32})
|
.DataTypeList({ge::DT_INT64})
|
||||||
.FormatList({ge::FORMAT_ND})
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.ValueDepend(OPTIONAL)
|
||||||
.AutoContiguous();
|
.AutoContiguous();
|
||||||
this->Input("hasInitialState")
|
this->Input("initialStateMode")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(OPTIONAL)
|
||||||
.DataTypeList({ge::DT_BOOL})
|
.DataTypeList({ge::DT_INT64})
|
||||||
.FormatList({ge::FORMAT_ND})
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.ValueDepend(OPTIONAL)
|
||||||
|
.AutoContiguous();
|
||||||
|
this->Input("numAcceptedTokens")
|
||||||
|
.ParamType(OPTIONAL)
|
||||||
|
.DataTypeList({ge::DT_INT64})
|
||||||
|
.FormatList({ge::FORMAT_ND})
|
||||||
|
.ValueDepend(OPTIONAL)
|
||||||
.AutoContiguous();
|
.AutoContiguous();
|
||||||
|
|
||||||
this->Output("y")
|
this->Output("y")
|
||||||
@@ -65,6 +74,7 @@ public:
|
|||||||
|
|
||||||
this->Attr("activationMode").AttrType(OPTIONAL).Int(0);
|
this->Attr("activationMode").AttrType(OPTIONAL).Int(0);
|
||||||
this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1);
|
this->Attr("padSlotId").AttrType(OPTIONAL).Int(-1);
|
||||||
|
this->Attr("runMode").AttrType(OPTIONAL).Int(0);
|
||||||
|
|
||||||
OpAICoreConfig aicoreConfig;
|
OpAICoreConfig aicoreConfig;
|
||||||
aicoreConfig.DynamicCompileStaticFlag(true)
|
aicoreConfig.DynamicCompileStaticFlag(true)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
* \brief
|
* \brief
|
||||||
*/
|
*/
|
||||||
#include "register/op_impl_registry.h"
|
#include "register/op_impl_registry.h"
|
||||||
#include "error_log.h"
|
#include "log/log.h"
|
||||||
|
|
||||||
using namespace ge;
|
using namespace ge;
|
||||||
|
|
||||||
@@ -23,25 +23,17 @@ static constexpr int64_t IDX_0 = 0;
|
|||||||
|
|
||||||
static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context)
|
static ge::graphStatus InferShapeCausalConv1d(gert::InferShapeContext* context)
|
||||||
{
|
{
|
||||||
// OPS_LOG_D(context->GetNodeName(), "Begin to do InferShapeCausalConv1d");
|
OP_LOGD(context->GetNodeName(), "Begin to do InferShapeCausalConv1d");
|
||||||
|
|
||||||
// get input shapes
|
|
||||||
const gert::Shape* xShape = context->GetInputShape(IDX_0);
|
const gert::Shape* xShape = context->GetInputShape(IDX_0);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
|
||||||
|
|
||||||
// get output shapes
|
|
||||||
gert::Shape* yShape = context->GetOutputShape(IDX_0);
|
gert::Shape* yShape = context->GetOutputShape(IDX_0);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
|
||||||
|
|
||||||
// 填充输出shape大小
|
*yShape = *xShape;
|
||||||
auto xShapeSize = xShape->GetDimNum();
|
|
||||||
yShape->SetDimNum(xShapeSize);
|
|
||||||
for (size_t i = 0; i < xShapeSize; i++) {
|
|
||||||
int64_t dim = xShape->GetDim(i);
|
|
||||||
yShape->SetDim(i, dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
// OPS_LOG_D(context->GetNodeName(), "End to do InferShapeCausalConv1d");
|
OP_LOGD(context->GetNodeName(), "End to do InferShapeCausalConv1d");
|
||||||
return GRAPH_SUCCESS;
|
return GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,12 +14,12 @@
|
|||||||
* \brief
|
* \brief
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// #include "error_log.h"
|
//#include "log/log.h"
|
||||||
#include "log/ops_log.h"
|
#include "error_log.h"
|
||||||
#include "../tiling_base/tiling_templates_registry.h"
|
#include "../tiling_base/tiling_templates_registry.h"
|
||||||
#include "../tiling_base/tiling_util.h"
|
#include "../tiling_base/tiling_util.h"
|
||||||
#include "math_util.h"
|
#include "math_util.h"
|
||||||
#include "causal_conv1d_tiling.h"
|
#include "../op_kernel/causal_conv1d_tiling_data.h"
|
||||||
#include "../op_kernel/causal_conv1d_tiling_key.h"
|
#include "../op_kernel/causal_conv1d_tiling_key.h"
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
@@ -35,12 +35,17 @@ constexpr uint32_t BIAS_INDEX = 2;
|
|||||||
constexpr uint32_t CONV_STATES_INDEX = 3;
|
constexpr uint32_t CONV_STATES_INDEX = 3;
|
||||||
constexpr uint32_t QUERY_START_LOC_INDEX = 4;
|
constexpr uint32_t QUERY_START_LOC_INDEX = 4;
|
||||||
constexpr uint32_t CACHE_INDICES_INDEX = 5;
|
constexpr uint32_t CACHE_INDICES_INDEX = 5;
|
||||||
constexpr uint32_t HAS_INITIAL_STATE_INDEX = 6;
|
constexpr uint32_t INITIAL_STATE_MODE_INDEX = 6;
|
||||||
|
constexpr uint32_t NUM_ACCEPTED_TOKENS_INDEX = 7;
|
||||||
|
|
||||||
constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0;
|
constexpr int32_t ATTR_ACTIVATION_MODE_INDEX = 0;
|
||||||
constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1;
|
constexpr int32_t ATTR_PAD_SLOT_ID_INDEX = 1;
|
||||||
|
constexpr int32_t ATTR_RUN_MODE_INDEX = 2;
|
||||||
|
|
||||||
|
struct CausalConv1dCompileInfo {
|
||||||
|
uint64_t ubSize = 0;
|
||||||
|
uint32_t coreNum = 0;
|
||||||
|
};
|
||||||
|
|
||||||
struct DimTileChoice {
|
struct DimTileChoice {
|
||||||
int64_t dimTileSize = 0;
|
int64_t dimTileSize = 0;
|
||||||
@@ -48,27 +53,46 @@ struct DimTileChoice {
|
|||||||
int64_t gridSize = 0;
|
int64_t gridSize = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static inline int64_t CeilDivInt64(int64_t x, int64_t y)
|
||||||
|
{
|
||||||
|
return (x + y - 1) / y;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool FitsInInt32(int64_t v)
|
||||||
|
{
|
||||||
|
return v >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
|
||||||
|
v <= static_cast<int64_t>(std::numeric_limits<int32_t>::max());
|
||||||
|
}
|
||||||
|
|
||||||
static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum)
|
static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int64_t batch, int64_t dim, uint32_t coreNum)
|
||||||
{
|
{
|
||||||
|
const int64_t candidates[] = {4096, 2048, 1024, 512, 384, 192};
|
||||||
|
|
||||||
const int64_t candidates[] = {4096, 2048, 1024, 512,384};
|
auto ChooseOnce = [&](bool requireExactDiv) -> DimTileChoice {
|
||||||
DimTileChoice bestOver;
|
DimTileChoice bestOver;
|
||||||
int64_t bestOverGap = std::numeric_limits<int64_t>::max();
|
int64_t bestOverGap = std::numeric_limits<int64_t>::max();
|
||||||
DimTileChoice bestUnder;
|
DimTileChoice bestUnder;
|
||||||
|
|
||||||
for (int64_t dimTileSize : candidates) {
|
for (int64_t dimTileSize : candidates) {
|
||||||
if (dim % dimTileSize != 0) {
|
if (dimTileSize <= 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const int64_t blocksPerSeq = dim / dimTileSize;
|
if (requireExactDiv && (dim % dimTileSize != 0)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const int64_t blocksPerSeq = requireExactDiv ? (dim / dimTileSize) : CeilDivInt64(dim, dimTileSize);
|
||||||
const int64_t gridSize = batch * blocksPerSeq;
|
const int64_t gridSize = batch * blocksPerSeq;
|
||||||
if (gridSize <= 0) {
|
if (gridSize <= 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
OP_LOGD(context,
|
||||||
|
"DimTile candidate[%s]: dimTileSize[%ld], blocksPerSeq[%ld], gridSize[%ld], coreNum[%u].",
|
||||||
|
requireExactDiv ? "exact" : "tail",
|
||||||
|
dimTileSize, blocksPerSeq, gridSize, coreNum);
|
||||||
if (gridSize >= static_cast<int64_t>(coreNum)) {
|
if (gridSize >= static_cast<int64_t>(coreNum)) {
|
||||||
const int64_t gap = gridSize - static_cast<int64_t>(coreNum);
|
const int64_t gap = gridSize - static_cast<int64_t>(coreNum);
|
||||||
if (gap < bestOverGap) {
|
if (gap < bestOverGap) {
|
||||||
|
// bestOver = {dimTileSize, blocksPerSeq, gridSize};
|
||||||
bestOver.dimTileSize = dimTileSize;
|
bestOver.dimTileSize = dimTileSize;
|
||||||
bestOver.blocksPerSeq = blocksPerSeq;
|
bestOver.blocksPerSeq = blocksPerSeq;
|
||||||
bestOver.gridSize = gridSize;
|
bestOver.gridSize = gridSize;
|
||||||
@@ -81,30 +105,28 @@ static inline DimTileChoice ChooseDimTileSize(gert::TilingContext* context, int6
|
|||||||
bestUnder.gridSize = gridSize;
|
bestUnder.gridSize = gridSize;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DimTileChoice result = (bestOver.dimTileSize != 0) ? bestOver : bestUnder;
|
return (bestOver.dimTileSize != 0) ? bestOver : bestUnder;
|
||||||
|
};
|
||||||
|
|
||||||
|
DimTileChoice result = ChooseOnce(true /*requireExactDiv*/);
|
||||||
|
if (result.dimTileSize == 0) {
|
||||||
|
result = ChooseOnce(false /*requireExactDiv*/);
|
||||||
|
}
|
||||||
|
OP_LOGD(context,
|
||||||
|
"DimTile chosen: dimTileSize[%ld], blocksPerSeq[%ld], gridSize[%ld].",
|
||||||
|
result.dimTileSize, result.blocksPerSeq, result.gridSize);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum)
|
static ge::graphStatus GetPlatformInfo(gert::TilingContext* context, uint64_t& ubSize, uint32_t& coreNum)
|
||||||
{
|
{
|
||||||
auto compileInfoPtr = context->GetCompileInfo<CausalConv1dCompileInfo>();
|
|
||||||
if (compileInfoPtr != nullptr && compileInfoPtr->coreNum != 0 && compileInfoPtr->ubSize != 0) {
|
|
||||||
ubSize = compileInfoPtr->ubSize;
|
|
||||||
coreNum = compileInfoPtr->coreNum;
|
|
||||||
return ge::GRAPH_SUCCESS;
|
|
||||||
}
|
|
||||||
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
fe::PlatFormInfos* platformInfoPtr = context->GetPlatformInfo();
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
||||||
coreNum = ascendcPlatform.GetCoreNumAiv();
|
coreNum = ascendcPlatform.GetCoreNumAiv();
|
||||||
if(coreNum == 0) {
|
OP_CHECK_IF(coreNum == 0, OP_LOGE(context, "coreNum is 0"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
|
||||||
if(ubSize == 0) {
|
OP_CHECK_IF(ubSize == 0, OP_LOGE(context, "ubSize is 0"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +138,8 @@ static ge::graphStatus GetWorkspaceSize(gert::TilingContext* context)
|
|||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId)
|
static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activationMode, int64_t& padSlotId,
|
||||||
|
int64_t& runMode)
|
||||||
{
|
{
|
||||||
auto attrs = context->GetAttrs();
|
auto attrs = context->GetAttrs();
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
|
||||||
@@ -124,17 +147,26 @@ static ge::graphStatus GetAttrsInfo(gert::TilingContext* context, int64_t& activ
|
|||||||
const int64_t* activationModePtr = attrs->GetAttrPointer<int64_t>(ATTR_ACTIVATION_MODE_INDEX);
|
const int64_t* activationModePtr = attrs->GetAttrPointer<int64_t>(ATTR_ACTIVATION_MODE_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, activationModePtr);
|
OP_CHECK_NULL_WITH_CONTEXT(context, activationModePtr);
|
||||||
activationMode = *activationModePtr;
|
activationMode = *activationModePtr;
|
||||||
if(activationMode != 0 && activationMode != 1){
|
OP_CHECK_IF(
|
||||||
return ge::GRAPH_FAILED;
|
activationMode != 0 && activationMode != 1, OP_LOGE(context, "activationMode only supports 0/1"),
|
||||||
}
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
const int64_t* padSlotIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_PAD_SLOT_ID_INDEX);
|
const int64_t* padSlotIdPtr = attrs->GetAttrPointer<int64_t>(ATTR_PAD_SLOT_ID_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, padSlotIdPtr);
|
OP_CHECK_NULL_WITH_CONTEXT(context, padSlotIdPtr);
|
||||||
padSlotId = *padSlotIdPtr;
|
padSlotId = *padSlotIdPtr;
|
||||||
|
|
||||||
|
const int64_t* runModePtr = attrs->GetAttrPointer<int64_t>(ATTR_RUN_MODE_INDEX);
|
||||||
|
runMode = (runModePtr == nullptr) ? 0 : *runModePtr;
|
||||||
|
OP_CHECK_IF(runMode != 0 && runMode != 1, OP_LOGE(context, "runMode only supports 0/1"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling)
|
static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalConv1dTilingData& tiling)
|
||||||
{
|
{
|
||||||
|
const bool isDecodeMode = (tiling.runMode == 1);
|
||||||
|
|
||||||
auto xShapePtr = context->GetInputShape(X_INDEX);
|
auto xShapePtr = context->GetInputShape(X_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr);
|
OP_CHECK_NULL_WITH_CONTEXT(context, xShapePtr);
|
||||||
auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape());
|
auto xShape = EnsureNotScalar(xShapePtr->GetStorageShape());
|
||||||
@@ -146,149 +178,361 @@ static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalCon
|
|||||||
int64_t inputMode = 0;
|
int64_t inputMode = 0;
|
||||||
|
|
||||||
if (xShape.GetDimNum() == 2) {
|
if (xShape.GetDimNum() == 2) {
|
||||||
|
if (isDecodeMode) {
|
||||||
|
inputMode = 2;
|
||||||
|
batch = xShape.GetDim(0);
|
||||||
|
dim = xShape.GetDim(1);
|
||||||
|
seqLen = 1;
|
||||||
|
cuSeqlen = batch;
|
||||||
|
OP_CHECK_IF(batch <= 0 || dim <= 0,
|
||||||
|
OP_LOGE(context, "invalid x shape for 2D decode mode"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
} else {
|
||||||
inputMode = 0;
|
inputMode = 0;
|
||||||
cuSeqlen = xShape.GetDim(0);
|
cuSeqlen = xShape.GetDim(0);
|
||||||
dim = xShape.GetDim(1);
|
dim = xShape.GetDim(1);
|
||||||
seqLen = 0;
|
seqLen = 0;
|
||||||
if(dim <= 0 || cuSeqlen < 0){
|
OP_CHECK_IF(dim <= 0 || cuSeqlen < 0,
|
||||||
return ge::GRAPH_FAILED;
|
OP_LOGE(context, "invalid x shape for 2D varlen mode"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (xShape.GetDimNum() == 3) {
|
} else if (xShape.GetDimNum() == 3) {
|
||||||
inputMode = 1;
|
inputMode = 1;
|
||||||
batch = xShape.GetDim(0);
|
batch = xShape.GetDim(0);
|
||||||
seqLen = xShape.GetDim(1);
|
seqLen = xShape.GetDim(1);
|
||||||
dim = xShape.GetDim(2);
|
dim = xShape.GetDim(2);
|
||||||
cuSeqlen = batch * seqLen;
|
cuSeqlen = batch * seqLen;
|
||||||
if(batch <= 0 || dim <= 0 || seqLen <= 0){
|
OP_CHECK_IF(batch <= 0 || dim <= 0 || seqLen <= 0, OP_LOGE(context, "invalid x shape for 3D batch mode"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
OP_LOGE(context, "x must be 2D (cu_seqlen, dim) or 3D (batch, seqlen, dim)");
|
||||||
return ge::GRAPH_FAILED;
|
return ge::GRAPH_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto wShapePtr = context->GetInputShape(WEIGHT_INDEX);
|
auto wShapePtr = context->GetInputShape(WEIGHT_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr);
|
OP_CHECK_NULL_WITH_CONTEXT(context, wShapePtr);
|
||||||
auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape());
|
auto wShape = EnsureNotScalar(wShapePtr->GetStorageShape());
|
||||||
if(wShape.GetDimNum() != 2){
|
OP_CHECK_IF(wShape.GetDimNum() != 2, OP_LOGE(context, "weight must be 2D: (width, dim)"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
const int64_t width = wShape.GetDim(0);
|
const int64_t width = wShape.GetDim(0);
|
||||||
const int64_t wDim = wShape.GetDim(1);
|
const int64_t wDim = wShape.GetDim(1);
|
||||||
if(wDim != dim){
|
OP_CHECK_IF(wDim != dim, OP_LOGE(context, "weight.shape[1] must equal dim"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;
|
OP_CHECK_IF(width < 2 || width > 4,
|
||||||
}
|
OP_LOGE(context, "Only support width in [2,4] now, actually is %ld.", width),
|
||||||
if(width != 4){
|
return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;
|
OP_CHECK_IF(dim % 16 != 0,
|
||||||
}
|
OP_LOGE(context, "dim must be a multiple of 16 for fp16/bf16 alignment, actually is %ld.", dim),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX);
|
auto sShapePtr = context->GetInputShape(CONV_STATES_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr);
|
OP_CHECK_NULL_WITH_CONTEXT(context, sShapePtr);
|
||||||
auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape());
|
auto sShape = EnsureNotScalar(sShapePtr->GetStorageShape());
|
||||||
if(sShape.GetDimNum() != 3){
|
OP_CHECK_IF(
|
||||||
return ge::GRAPH_FAILED;
|
sShape.GetDimNum() != 3, OP_LOGE(context, "convStates must be 3D: (num_cache_lines, state_len, dim)"),
|
||||||
}
|
return ge::GRAPH_FAILED);
|
||||||
const int64_t numCacheLines = sShape.GetDim(0);
|
const int64_t numCacheLines = sShape.GetDim(0);
|
||||||
const int64_t stateLen = sShape.GetDim(1);
|
const int64_t stateLen = sShape.GetDim(1);
|
||||||
const int64_t sDim = sShape.GetDim(2);
|
const int64_t sDim = sShape.GetDim(2);
|
||||||
if(numCacheLines <= 0){
|
OP_CHECK_IF(numCacheLines <= 0, OP_LOGE(context, "convStates.shape[0] (num_cache_lines) must be > 0"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;}
|
OP_CHECK_IF(sDim != dim, OP_LOGE(context, "convStates.shape[2] must equal dim"), return ge::GRAPH_FAILED);
|
||||||
if(sDim != dim){
|
OP_CHECK_IF(stateLen < (width - 1), OP_LOGE(context, "convStates.shape[1] must be >= width-1"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
if(stateLen < (width - 1)){
|
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
|
|
||||||
auto qslShapePtr = context->GetInputShape(QUERY_START_LOC_INDEX);
|
auto qslShapePtr = context->GetOptionalInputShape(QUERY_START_LOC_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, qslShapePtr);
|
const gert::CompileTimeTensorDesc* qslDesc = context->GetOptionalInputDesc(QUERY_START_LOC_INDEX);
|
||||||
auto qslShape = EnsureNotScalar(qslShapePtr->GetStorageShape());
|
bool qslAbsent = true;
|
||||||
if(qslShape.GetDimNum() != 1){
|
int64_t qslSize = 0;
|
||||||
return ge::GRAPH_FAILED;}
|
if (qslShapePtr != nullptr) {
|
||||||
const int64_t qslSize = qslShape.GetDim(0);
|
const auto qslStorageShape = qslShapePtr->GetStorageShape();
|
||||||
if(qslSize < 1){
|
const int64_t qslDimNum = qslStorageShape.GetDimNum();
|
||||||
return ge::GRAPH_FAILED;}
|
qslAbsent = (qslDimNum == 0) || (qslDimNum == 1 && qslStorageShape.GetDim(0) <= 0);
|
||||||
|
|
||||||
|
if (!qslAbsent) {
|
||||||
|
auto qslShape = EnsureNotScalar(qslStorageShape);
|
||||||
|
OP_CHECK_IF(qslShape.GetDimNum() != 1, OP_LOGE(context, "queryStartLoc must be 1D"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
qslSize = qslShape.GetDim(0);
|
||||||
|
OP_CHECK_IF(qslSize < 1, OP_LOGE(context, "queryStartLoc.size must be >= 1"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
|
OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc);
|
||||||
|
OP_CHECK_IF(qslDesc->GetDataType() != ge::DT_INT64,
|
||||||
|
OP_LOGE(context, "queryStartLoc dtype must be int64"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (qslAbsent) {
|
||||||
|
OP_CHECK_IF(inputMode == 0,
|
||||||
|
OP_LOGE(context, "queryStartLoc is required in 2D varlen mode (inputMode=0)"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
qslSize = batch + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
OP_CHECK_IF(cuSeqlen > static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
|
||||||
|
OP_LOGE(context, "cuSeqlen is too large for int32 indexing, got %ld", cuSeqlen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
|
const int64_t* qslData = nullptr;
|
||||||
|
if (!qslAbsent) {
|
||||||
|
const gert::Tensor* qslTensor = context->GetOptionalInputTensor(QUERY_START_LOC_INDEX);
|
||||||
|
qslData = (qslTensor != nullptr) ? qslTensor->GetData<int64_t>() : nullptr;
|
||||||
|
if (qslData != nullptr) {
|
||||||
|
OP_CHECK_IF(qslData[0] != 0, OP_LOGE(context, "queryStartLoc[0] must be 0"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(qslData[qslSize - 1] != cuSeqlen,
|
||||||
|
OP_LOGE(context, "queryStartLoc[last] must equal cuSeqlen, got %ld vs %ld",
|
||||||
|
qslData[qslSize - 1], cuSeqlen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
for (int64_t i = 0; i + 1 < qslSize; ++i) {
|
||||||
|
const int64_t cur = qslData[i];
|
||||||
|
const int64_t nxt = qslData[i + 1];
|
||||||
|
OP_CHECK_IF(cur < 0 || cur > cuSeqlen,
|
||||||
|
OP_LOGE(context, "queryStartLoc[%ld] out of range: %ld (cuSeqlen=%ld)", i, cur, cuSeqlen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(nxt < 0 || nxt > cuSeqlen,
|
||||||
|
OP_LOGE(context, "queryStartLoc[%ld] out of range: %ld (cuSeqlen=%ld)", i + 1, nxt, cuSeqlen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(nxt < cur,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"queryStartLoc must be non-decreasing, got queryStartLoc[%ld]=%ld queryStartLoc[%ld]=%ld",
|
||||||
|
i, cur, i + 1, nxt),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!qslAbsent && isDecodeMode && inputMode == 2) {
|
||||||
|
const int64_t batchFromQsl = qslSize - 1;
|
||||||
|
if (batchFromQsl != batch) {
|
||||||
|
inputMode = 0;
|
||||||
|
cuSeqlen = xShape.GetDim(0);
|
||||||
|
batch = batchFromQsl;
|
||||||
|
seqLen = 0;
|
||||||
|
OP_CHECK_IF(dim <= 0 || cuSeqlen < 0 || batch < 0,
|
||||||
|
OP_LOGE(context, "invalid x/queryStartLoc shapes for 2D varlen decode mode"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (inputMode == 0) {
|
if (inputMode == 0) {
|
||||||
batch = qslSize - 1;
|
batch = qslSize - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!qslAbsent && (inputMode == 1 || inputMode == 2)) {
|
||||||
|
OP_CHECK_IF(qslSize != batch + 1, OP_LOGE(context, "queryStartLoc.size must equal batch + 1"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isDecodeMode) {
|
||||||
|
const int64_t decodeSeqLen = (inputMode == 1) ? seqLen : 1;
|
||||||
|
OP_CHECK_IF(decodeSeqLen < 1,
|
||||||
|
OP_LOGE(context, "decode mode requires seqlen >= 1, actual is %ld", decodeSeqLen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
|
tiling.hasCacheIndices = 0;
|
||||||
|
bool ciAbsent = true;
|
||||||
|
auto ciShapePtr = context->GetOptionalInputShape(CACHE_INDICES_INDEX);
|
||||||
|
if (ciShapePtr != nullptr) {
|
||||||
|
const auto ciStorageShape = ciShapePtr->GetStorageShape();
|
||||||
|
const int64_t ciDimNum = ciStorageShape.GetDimNum();
|
||||||
|
ciAbsent = (ciDimNum == 0) || (ciDimNum == 1 && ciStorageShape.GetDim(0) <= 0);
|
||||||
|
if (!ciAbsent) {
|
||||||
|
auto ciShape = EnsureNotScalar(ciStorageShape);
|
||||||
|
OP_CHECK_IF(ciShape.GetDimNum() != 1, OP_LOGE(context, "cacheIndices must be 1D"), return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(ciShape.GetDim(0) != batch, OP_LOGE(context, "cacheIndices.size must equal batch"), return ge::GRAPH_FAILED);
|
||||||
|
tiling.hasCacheIndices = 1;
|
||||||
|
|
||||||
|
const gert::Tensor* ciTensor = context->GetOptionalInputTensor(CACHE_INDICES_INDEX);
|
||||||
|
const int64_t* ciData = (ciTensor != nullptr) ? ciTensor->GetData<int64_t>() : nullptr;
|
||||||
|
if (ciData != nullptr) {
|
||||||
|
for (int64_t i = 0; i < batch; ++i) {
|
||||||
|
const int64_t v = ciData[i];
|
||||||
|
if (v == tiling.padSlotId) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
OP_CHECK_IF(!FitsInInt32(v),
|
||||||
|
OP_LOGE(context, "cacheIndices[%ld]=%ld does not fit int32", i, v),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(v < 0 || v >= numCacheLines,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"cacheIndices[%ld]=%ld out of range [0, num_cache_lines=%ld), padSlotId=%ld",
|
||||||
|
i, v, numCacheLines, tiling.padSlotId),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (ciAbsent) {
|
||||||
|
OP_CHECK_IF(numCacheLines < batch,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"cacheIndices is absent, requires convStates.shape[0] (num_cache_lines) >= batch for identity mapping, got num_cache_lines=%ld batch=%ld",
|
||||||
|
numCacheLines, batch),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
|
tiling.hasInitialStateMode = 0;
|
||||||
|
auto ismShapePtr = context->GetOptionalInputShape(INITIAL_STATE_MODE_INDEX);
|
||||||
|
if (ismShapePtr != nullptr) {
|
||||||
|
const auto ismStorageShape = ismShapePtr->GetStorageShape();
|
||||||
|
const int64_t ismDimNum = ismStorageShape.GetDimNum();
|
||||||
|
const bool ismAbsent = (ismDimNum == 0) || (ismDimNum == 1 && ismStorageShape.GetDim(0) <= 0);
|
||||||
|
if (!ismAbsent) {
|
||||||
|
auto ismShape = EnsureNotScalar(ismStorageShape);
|
||||||
|
OP_CHECK_IF(ismShape.GetDimNum() != 1, OP_LOGE(context, "initialStateMode must be 1D"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(ismShape.GetDim(0) != batch, OP_LOGE(context, "initialStateMode.size must equal batch"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
tiling.hasInitialStateMode = 1;
|
||||||
|
|
||||||
|
const gert::Tensor* ismTensor = context->GetOptionalInputTensor(INITIAL_STATE_MODE_INDEX);
|
||||||
|
const int64_t* ismData = (ismTensor != nullptr) ? ismTensor->GetData<int64_t>() : nullptr;
|
||||||
|
if (ismData != nullptr) {
|
||||||
|
for (int64_t i = 0; i < batch; ++i) {
|
||||||
|
const int64_t v = ismData[i];
|
||||||
|
OP_CHECK_IF(v != 0 && v != 1,
|
||||||
|
OP_LOGE(context, "initialStateMode[%ld]=%ld is invalid (only supports 0/1)", i, v),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tiling.hasNumAcceptedTokens = 0;
|
||||||
|
auto natShapePtr = context->GetOptionalInputShape(NUM_ACCEPTED_TOKENS_INDEX);
|
||||||
|
if (natShapePtr != nullptr) {
|
||||||
|
const auto natStorageShape = natShapePtr->GetStorageShape();
|
||||||
|
const int64_t natDimNum = natStorageShape.GetDimNum();
|
||||||
|
const bool natAbsent = (natDimNum == 0) || (natDimNum == 1 && natStorageShape.GetDim(0) <= 0);
|
||||||
|
if (!natAbsent) {
|
||||||
|
OP_CHECK_IF(!isDecodeMode,
|
||||||
|
OP_LOGE(context, "numAcceptedTokens is only supported in runMode=1 (decode/update)"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
auto natShape = EnsureNotScalar(natStorageShape);
|
||||||
|
OP_CHECK_IF(natShape.GetDimNum() != 1, OP_LOGE(context, "numAcceptedTokens must be 1D"), return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(natShape.GetDim(0) != batch, OP_LOGE(context, "numAcceptedTokens.size must equal batch"), return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
if (inputMode == 1) {
|
if (inputMode == 1) {
|
||||||
if(qslSize != batch + 1){
|
const int64_t reqStateLen = (width - 1) + (seqLen - 1);
|
||||||
return ge::GRAPH_FAILED;
|
OP_CHECK_IF(
|
||||||
|
stateLen < reqStateLen,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"spec decode requires stateLen >= (width-1) + (seqlen-1), got stateLen=%ld req=%ld",
|
||||||
|
stateLen, reqStateLen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
|
const gert::Tensor* natTensor = context->GetOptionalInputTensor(NUM_ACCEPTED_TOKENS_INDEX);
|
||||||
|
const int64_t* natData = (natTensor != nullptr) ? natTensor->GetData<int64_t>() : nullptr;
|
||||||
|
if (natData != nullptr) {
|
||||||
|
for (int64_t i = 0; i < batch; ++i) {
|
||||||
|
const int64_t a = natData[i];
|
||||||
|
OP_CHECK_IF(a < 0,
|
||||||
|
OP_LOGE(context, "numAcceptedTokens[%ld]=%ld is invalid (must be >= 0)", i, a),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(!FitsInInt32(a),
|
||||||
|
OP_LOGE(context, "numAcceptedTokens[%ld]=%ld does not fit int32", i, a),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
|
if (inputMode == 2) {
|
||||||
|
OP_CHECK_IF(a > 1,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"numAcceptedTokens[%ld]=%ld exceeds decode 2D token count (1)", i, a),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
} else if (inputMode == 1) {
|
||||||
|
OP_CHECK_IF(a > seqLen,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"numAcceptedTokens[%ld]=%ld exceeds seqlen=%ld in 3D update", i, a, seqLen),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
} else if (inputMode == 0) {
|
||||||
|
if (qslData != nullptr) {
|
||||||
|
const int64_t lenI = qslData[i + 1] - qslData[i];
|
||||||
|
OP_CHECK_IF(a > lenI,
|
||||||
|
OP_LOGE(context,
|
||||||
|
"numAcceptedTokens[%ld]=%ld exceeds varlen segment length=%ld",
|
||||||
|
i, a, lenI),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX);
|
tiling.hasNumAcceptedTokens = 1;
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr);
|
}
|
||||||
auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape());
|
}
|
||||||
if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;}
|
|
||||||
if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;}
|
|
||||||
|
|
||||||
auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX);
|
tiling.hasBias = 0;
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr);
|
|
||||||
auto hisShape = EnsureNotScalar(hisShapePtr->GetStorageShape());
|
|
||||||
if(hisShape.GetDimNum() != 1){
|
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
if(hisShape.GetDim(0) != batch){
|
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
|
|
||||||
tiling.set_hasBias(0);
|
|
||||||
auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX);
|
auto biasShapePtr = context->GetOptionalInputShape(BIAS_INDEX);
|
||||||
if (biasShapePtr != nullptr && biasShapePtr->GetStorageShape().GetDimNum() != 0) {
|
if (biasShapePtr != nullptr) {
|
||||||
auto biasShape = EnsureNotScalar(biasShapePtr->GetStorageShape());
|
const auto biasStorageShape = biasShapePtr->GetStorageShape();
|
||||||
if(biasShape.GetDimNum() != 1){
|
const int64_t biasDimNum = biasStorageShape.GetDimNum();
|
||||||
return ge::GRAPH_FAILED;}
|
const bool biasAbsent = (biasDimNum == 0) || (biasDimNum == 1 && biasStorageShape.GetDim(0) <= 0);
|
||||||
if(biasShape.GetDim(0) != dim){
|
if (!biasAbsent) {
|
||||||
return ge::GRAPH_FAILED;}
|
auto biasShape = EnsureNotScalar(biasStorageShape);
|
||||||
tiling.set_hasBias(1);
|
OP_CHECK_IF(biasShape.GetDimNum() != 1, OP_LOGE(context, "bias must be 1D: (dim,)"), return ge::GRAPH_FAILED);
|
||||||
|
OP_CHECK_IF(biasShape.GetDim(0) != dim, OP_LOGE(context, "bias.size must equal dim"), return ge::GRAPH_FAILED);
|
||||||
|
tiling.hasBias = 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::set<ge::DataType> supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16};
|
const std::set<ge::DataType> supportedXDtype = {ge::DT_BF16, ge::DT_FLOAT16};
|
||||||
auto xDesc = context->GetInputDesc(X_INDEX);
|
auto xDesc = context->GetInputDesc(X_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, xDesc);
|
OP_CHECK_NULL_WITH_CONTEXT(context, xDesc);
|
||||||
const ge::DataType xDtype = xDesc->GetDataType();
|
const ge::DataType xDtype = xDesc->GetDataType();
|
||||||
if(supportedXDtype.count(xDtype) == 0){
|
OP_CHECK_IF(supportedXDtype.count(xDtype) == 0, OP_LOGE(context, "x dtype only supports bf16/fp16"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
|
|
||||||
auto wDesc = context->GetInputDesc(WEIGHT_INDEX);
|
auto wDesc = context->GetInputDesc(WEIGHT_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, wDesc);
|
OP_CHECK_NULL_WITH_CONTEXT(context, wDesc);
|
||||||
if(wDesc->GetDataType() != xDtype){
|
OP_CHECK_IF(wDesc->GetDataType() != xDtype, OP_LOGE(context, "weight dtype must equal x dtype"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
|
|
||||||
if (tiling.get_hasBias() == 1) {
|
if (tiling.hasBias == 1) {
|
||||||
auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX);
|
auto biasDesc = context->GetOptionalInputDesc(BIAS_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc);
|
OP_CHECK_NULL_WITH_CONTEXT(context, biasDesc);
|
||||||
if(biasDesc->GetDataType() != xDtype){
|
OP_CHECK_IF(biasDesc->GetDataType() != xDtype, OP_LOGE(context, "bias dtype must equal x dtype"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sDesc = context->GetInputDesc(CONV_STATES_INDEX);
|
auto sDesc = context->GetInputDesc(CONV_STATES_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, sDesc);
|
OP_CHECK_NULL_WITH_CONTEXT(context, sDesc);
|
||||||
if(sDesc->GetDataType() != xDtype){
|
OP_CHECK_IF(sDesc->GetDataType() != xDtype, OP_LOGE(context, "convStates dtype must equal x dtype"), return ge::GRAPH_FAILED);
|
||||||
return ge::GRAPH_FAILED;}
|
|
||||||
|
|
||||||
auto qslDesc = context->GetInputDesc(QUERY_START_LOC_INDEX);
|
if (!qslAbsent) {
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc);
|
auto qslDesc2 = context->GetOptionalInputDesc(QUERY_START_LOC_INDEX);
|
||||||
if(qslDesc->GetDataType() != ge::DT_INT32){
|
OP_CHECK_NULL_WITH_CONTEXT(context, qslDesc2);
|
||||||
return ge::GRAPH_FAILED;}
|
OP_CHECK_IF(qslDesc2->GetDataType() != ge::DT_INT64, OP_LOGE(context, "queryStartLoc dtype must be int64"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
auto ciDesc = context->GetInputDesc(CACHE_INDICES_INDEX);
|
if (tiling.hasCacheIndices == 1) {
|
||||||
|
auto ciDesc = context->GetOptionalInputDesc(CACHE_INDICES_INDEX);
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc);
|
OP_CHECK_NULL_WITH_CONTEXT(context, ciDesc);
|
||||||
if(ciDesc->GetDataType() != ge::DT_INT32){
|
OP_CHECK_IF(ciDesc->GetDataType() != ge::DT_INT64, OP_LOGE(context, "cacheIndices dtype must be int64"),
|
||||||
return ge::GRAPH_FAILED;}
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
auto hisDesc = context->GetInputDesc(HAS_INITIAL_STATE_INDEX);
|
if (tiling.hasInitialStateMode == 1) {
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, hisDesc);
|
auto ismDesc = context->GetOptionalInputDesc(INITIAL_STATE_MODE_INDEX);
|
||||||
if(hisDesc->GetDataType() != ge::DT_BOOL){
|
OP_CHECK_NULL_WITH_CONTEXT(context, ismDesc);
|
||||||
return ge::GRAPH_FAILED;}
|
OP_CHECK_IF(ismDesc->GetDataType() != ge::DT_INT64,
|
||||||
|
OP_LOGE(context, "initialStateMode dtype must be int64"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
tiling.set_dim(dim);
|
if (tiling.hasNumAcceptedTokens == 1) {
|
||||||
tiling.set_cuSeqlen(cuSeqlen);
|
OP_CHECK_IF(width != 4,
|
||||||
tiling.set_seqLen(seqLen);
|
OP_LOGE(context, "numAcceptedTokens is only supported for width=4 currently"),
|
||||||
tiling.set_inputMode(inputMode);
|
return ge::GRAPH_FAILED);
|
||||||
tiling.set_width(width);
|
auto natDesc = context->GetOptionalInputDesc(NUM_ACCEPTED_TOKENS_INDEX);
|
||||||
tiling.set_stateLen(stateLen);
|
OP_CHECK_NULL_WITH_CONTEXT(context, natDesc);
|
||||||
tiling.set_numCacheLines(numCacheLines);
|
OP_CHECK_IF(natDesc->GetDataType() != ge::DT_INT64, OP_LOGE(context, "numAcceptedTokens dtype must be int64"),
|
||||||
tiling.set_batch(batch);
|
return ge::GRAPH_FAILED);
|
||||||
|
}
|
||||||
|
|
||||||
|
tiling.dim = dim;
|
||||||
|
tiling.cuSeqlen = cuSeqlen;
|
||||||
|
tiling.seqLen = seqLen;
|
||||||
|
tiling.inputMode = inputMode;
|
||||||
|
tiling.width = width;
|
||||||
|
tiling.stateLen = stateLen;
|
||||||
|
tiling.numCacheLines = numCacheLines;
|
||||||
|
tiling.batch = batch;
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,66 +540,57 @@ static ge::graphStatus CausalConv1dTilingFunc(gert::TilingContext* context)
|
|||||||
{
|
{
|
||||||
uint64_t ubSize;
|
uint64_t ubSize;
|
||||||
uint32_t coreNum;
|
uint32_t coreNum;
|
||||||
if( GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS){
|
OP_CHECK_IF(
|
||||||
return ge::GRAPH_FAILED;
|
GetPlatformInfo(context, ubSize, coreNum) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetPlatformInfo error"),
|
||||||
}
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
if(GetWorkspaceSize(context) != ge::GRAPH_SUCCESS){
|
OP_CHECK_IF(
|
||||||
return ge::GRAPH_FAILED;
|
GetWorkspaceSize(context) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetWorkspaceSize error"),
|
||||||
}
|
return ge::GRAPH_FAILED);
|
||||||
CausalConv1dTilingData tilingData;
|
|
||||||
|
|
||||||
int64_t activationMode = 0;
|
CausalConv1dTilingData* tiling = context->GetTilingData<CausalConv1dTilingData>();
|
||||||
int64_t padSlotId = -1;
|
OP_CHECK_NULL_WITH_CONTEXT(context, tiling);
|
||||||
if(GetAttrsInfo(context, activationMode, padSlotId) != ge::GRAPH_SUCCESS){
|
OP_CHECK_IF(
|
||||||
return ge::GRAPH_FAILED;
|
memset_s(tiling, sizeof(CausalConv1dTilingData), 0, sizeof(CausalConv1dTilingData)) != EOK,
|
||||||
}
|
OP_LOGE(context, "set tiling data error"), return ge::GRAPH_FAILED);
|
||||||
tilingData.set_activationMode(activationMode);
|
|
||||||
tilingData.set_padSlotId(padSlotId);
|
|
||||||
|
|
||||||
if( GetShapeDtypeInfo(context, tilingData) != ge::GRAPH_SUCCESS){
|
OP_CHECK_IF(
|
||||||
return ge::GRAPH_FAILED;
|
GetAttrsInfo(context, tiling->activationMode, tiling->padSlotId, tiling->runMode) != ge::GRAPH_SUCCESS,
|
||||||
}
|
OP_LOGE(context, "GetAttrsInfo error"), return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
|
OP_CHECK_IF(
|
||||||
|
GetShapeDtypeInfo(context, *tiling) != ge::GRAPH_SUCCESS, OP_LOGE(context, "GetShapeDtypeInfo error"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
|
const int64_t dim = tiling->dim;
|
||||||
|
const int64_t batch = tiling->batch;
|
||||||
|
OP_CHECK_IF(dim <= 0 || batch <= 0, OP_LOGE(context, "dim/batch must be positive"), return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
const int64_t dim = tilingData.get_dim();
|
|
||||||
const int64_t batch = tilingData.get_batch();
|
|
||||||
if(dim <= 0 || batch <= 0){
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum);
|
const DimTileChoice choice = ChooseDimTileSize(context, batch, dim, coreNum);
|
||||||
|
OP_CHECK_IF(choice.dimTileSize <= 0 || choice.blocksPerSeq <= 0 || choice.gridSize <= 0,
|
||||||
|
OP_LOGE(context, "invalid dim_tile_size selection"),
|
||||||
|
return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
const uint32_t blockDim = (choice.gridSize < static_cast<int64_t>(coreNum))
|
const uint32_t blockDim = (choice.gridSize < static_cast<int64_t>(coreNum))
|
||||||
? static_cast<uint32_t>(choice.gridSize)
|
? static_cast<uint32_t>(choice.gridSize)
|
||||||
: coreNum;
|
: coreNum;
|
||||||
|
|
||||||
|
OP_LOGD(context,
|
||||||
|
"Tiling result: batch[%ld], dim[%ld], dimTileSize[%ld], blocksPerSeq[%ld], gridSize[%ld], blockDim[%u], coreNum[%u].",
|
||||||
|
batch, dim, choice.dimTileSize, choice.blocksPerSeq, choice.gridSize, blockDim, coreNum);
|
||||||
|
|
||||||
context->SetBlockDim(blockDim);
|
context->SetBlockDim(blockDim);
|
||||||
tilingData.set_dimTileSize(choice.dimTileSize);
|
tiling->dimTileSize = choice.dimTileSize;
|
||||||
tilingData.set_blocksPerSeq(choice.blocksPerSeq);
|
tiling->blocksPerSeq = choice.blocksPerSeq;
|
||||||
|
|
||||||
const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT);
|
const uint64_t tilingKey = GET_TPL_TILING_KEY(CAUSAL_CONV1D_TPL_SCH_MODE_DEFAULT);
|
||||||
context->SetTilingKey(tilingKey);
|
context->SetTilingKey(tilingKey);
|
||||||
|
|
||||||
tilingData.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
|
|
||||||
context->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
|
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context)
|
static ge::graphStatus TilingParseForCausalConv1d(gert::TilingParseContext* context)
|
||||||
{
|
{
|
||||||
auto platformInfoPtr = context->GetPlatformInfo();
|
OP_LOGD(context, "Enter TilingParseForCausalConv1d.");
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, platformInfoPtr);
|
|
||||||
auto compileInfoPtr = context->GetCompiledInfo<CausalConv1dCompileInfo>();
|
|
||||||
OP_CHECK_NULL_WITH_CONTEXT(context, compileInfoPtr);
|
|
||||||
|
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
|
|
||||||
compileInfoPtr->coreNum = static_cast<uint32_t>(ascendcPlatform.GetCoreNumAiv());
|
|
||||||
if(compileInfoPtr->coreNum == 0){
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, compileInfoPtr->ubSize);
|
|
||||||
if(compileInfoPtr->ubSize == 0){
|
|
||||||
return ge::GRAPH_FAILED;
|
|
||||||
}
|
|
||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
/**
|
|
||||||
* This program is free software, you can redistribute it and/or modify it.
|
|
||||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
|
||||||
* This file is a part of the CANN Open Software.
|
|
||||||
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
|
||||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
|
||||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
|
||||||
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \file causal_conv1d_tiling_data.h
|
|
||||||
* \brief
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
|
||||||
#define ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
// #include "register/tilingdata_base.h"
|
|
||||||
// #include "tiling/tiling_api.h"
|
|
||||||
#include "register/tilingdata_base.h"
|
|
||||||
#include "error_log.h"
|
|
||||||
#include "register/op_impl_registry.h"
|
|
||||||
#include "tiling/platform/platform_ascendc.h"
|
|
||||||
#include "platform/platform_infos_def.h"
|
|
||||||
namespace optiling {
|
|
||||||
|
|
||||||
BEGIN_TILING_DATA_DEF(CausalConv1dTilingData)
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, dim);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, cuSeqlen);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, seqLen);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, inputMode);
|
|
||||||
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, width);
|
|
||||||
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, stateLen);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, numCacheLines);
|
|
||||||
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, batch);
|
|
||||||
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, activationMode);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, padSlotId);
|
|
||||||
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, hasBias);
|
|
||||||
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, dimTileSize);
|
|
||||||
TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq);
|
|
||||||
END_TILING_DATA_DEF;
|
|
||||||
struct CausalConv1dCompileInfo {
|
|
||||||
uint64_t ubSize = 0;
|
|
||||||
uint32_t coreNum = 0;
|
|
||||||
};
|
|
||||||
REGISTER_TILING_DATA_CLASS(CausalConv1d, CausalConv1dTilingData)
|
|
||||||
|
|
||||||
} // namespace optiling
|
|
||||||
|
|
||||||
#endif // ASCEND_OPS_CAUSAL_CONV1D_TILING_DATA_H
|
|
||||||
@@ -18,13 +18,16 @@
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
// NOTE:
|
||||||
|
// Dtype is provided via AscendC compile macros (e.g. DTYPE_X / ORIG_DTYPE_X), so tiling key does not need to carry dtype.
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
__aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
__aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
||||||
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
|
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR initialStateMode,
|
||||||
GM_ADDR y, const NsCausalConv1d::CausalConv1dTilingData* tilingData)
|
GM_ADDR numAcceptedTokens, GM_ADDR y, const CausalConv1dTilingData* tilingData)
|
||||||
{
|
{
|
||||||
NsCausalConv1d::CausalConv1d<T> op;
|
NsCausalConv1d::CausalConv1d<T> op;
|
||||||
op.Init(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, tilingData);
|
op.Init(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, tilingData);
|
||||||
op.Process();
|
op.Process();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,27 +35,24 @@ __aicore__ inline void RunCausalConv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias,
|
|||||||
|
|
||||||
template <uint32_t schMode>
|
template <uint32_t schMode>
|
||||||
__global__ __aicore__ void causal_conv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
__global__ __aicore__ void causal_conv1d(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
||||||
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
|
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR initialStateMode,
|
||||||
GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
|
GM_ADDR numAcceptedTokens, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
|
||||||
{
|
{
|
||||||
REGISTER_TILING_DEFAULT( NsCausalConv1d::CausalConv1dTilingData);
|
REGISTER_TILING_DEFAULT(CausalConv1dTilingData);
|
||||||
// GET_TILING_DATA_WITH_STRUCT( NsCausalConv1d::CausalConv1dTilingData, tilingData, tiling);
|
GET_TILING_DATA_WITH_STRUCT(CausalConv1dTilingData, tilingData, tiling);
|
||||||
GET_TILING_DATA(tilingData, tiling);
|
|
||||||
#if defined(ORIG_DTYPE_X)
|
#if defined(ORIG_DTYPE_X)
|
||||||
#if (ORIG_DTYPE_X == DT_FLOAT16)
|
#if (ORIG_DTYPE_X == DT_FLOAT16)
|
||||||
RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData);
|
||||||
#elif (ORIG_DTYPE_X == DT_BF16)
|
#elif (ORIG_DTYPE_X == DT_BF16)
|
||||||
RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData);
|
||||||
#elif (ORIG_DTYPE_X == DT_FLOAT)
|
|
||||||
RunCausalConv1d<float>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
#if (DTYPE_X == DT_FLOAT16)
|
#if (DTYPE_X == DT_FLOAT16)
|
||||||
RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
RunCausalConv1d<half>(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData);
|
||||||
#elif (DTYPE_X == DT_BF16)
|
#elif (DTYPE_X == DT_BF16)
|
||||||
RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
RunCausalConv1d<bfloat16_t>(x, weight, bias, convStates, queryStartLoc, cacheIndices, initialStateMode, numAcceptedTokens, y, &tilingData);
|
||||||
#elif (DTYPE_X == DT_FLOAT)
|
|
||||||
RunCausalConv1d<float>(x, weight, bias, convStates, queryStartLoc, cacheIndices, hasInitialState, y, &tilingData);
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,90 +12,23 @@
|
|||||||
/*!
|
/*!
|
||||||
* \file causal_conv1d.h
|
* \file causal_conv1d.h
|
||||||
* \brief CausalConv1D (prefill/extend) AscendC kernel implementation.
|
* \brief CausalConv1D (prefill/extend) AscendC kernel implementation.
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef CAUSAL_CONV1D_H
|
#ifndef CAUSAL_CONV1D_H
|
||||||
#define CAUSAL_CONV1D_H
|
#define CAUSAL_CONV1D_H
|
||||||
|
|
||||||
#include "kernel_operator.h"
|
#include "kernel_operator.h"
|
||||||
// #include "kernel_tiling/kernel_tiling.h"
|
#include "kernel_tiling/kernel_tiling.h"
|
||||||
|
#include "causal_conv1d_tiling_data.h"
|
||||||
#include "causal_conv1d_tiling_key.h"
|
#include "causal_conv1d_tiling_key.h"
|
||||||
#include "causal_conv1d_common.h"
|
#include "causal_conv1d_common.h"
|
||||||
|
|
||||||
// #define ENABLE_CAUSAL_CONV1D_DEBUG
|
|
||||||
|
|
||||||
// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG
|
|
||||||
// #define CCONV_PRINTF(fmt, ...) printf(fmt, ##__VA_ARGS__)
|
|
||||||
// #else
|
|
||||||
// #define CCONV_PRINTF(fmt, ...)
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
// #define CCONV_PRINT_IF(cond, fmt, ...) \
|
|
||||||
// do { \
|
|
||||||
// if (cond) { \
|
|
||||||
// CCONV_PRINTF(fmt, ##__VA_ARGS__); \
|
|
||||||
// } \
|
|
||||||
// } while (0)
|
|
||||||
|
|
||||||
// #ifdef ENABLE_CAUSAL_CONV1D_DEBUG
|
|
||||||
|
|
||||||
// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \
|
|
||||||
// do { \
|
|
||||||
// if (cond) { \
|
|
||||||
// DumpTensor(tensor, __LINE__, size); \
|
|
||||||
// } \
|
|
||||||
// } while (0)
|
|
||||||
// #else
|
|
||||||
constexpr int32_t CCONV_DBG_SEQ = -1;
|
|
||||||
constexpr int32_t CCONV_DBG_C0 = -1;
|
|
||||||
constexpr int32_t CCONV_DBG_MAX_TOKENS = 0;
|
|
||||||
constexpr int32_t CCONV_DBG_VERBOSE_TOKENS = 0;
|
|
||||||
constexpr int32_t CCONV_DBG_DUMP_SIZE = 0;
|
|
||||||
constexpr bool CCONV_DBG_PRINT_SYNC = false;
|
|
||||||
constexpr bool CCONV_DBG_DUMP_WEIGHTS = false;
|
|
||||||
constexpr bool CCONV_DBG_DUMP_BIAS = false;
|
|
||||||
constexpr bool CCONV_DBG_DUMP_INIT_RING = false;
|
|
||||||
constexpr bool CCONV_DBG_DUMP_RUNSEQ = false;
|
|
||||||
constexpr bool CCONV_DBG_DUMP_PREFETCH = false;
|
|
||||||
constexpr bool CCONV_DBG_DUMP_STATE = false;
|
|
||||||
|
|
||||||
// #define CCONV_DUMP_TENSOR_IF(cond, tensor, size) \
|
|
||||||
// do { \
|
|
||||||
// } while (0)
|
|
||||||
// #endif
|
|
||||||
using namespace AscendC;
|
|
||||||
namespace NsCausalConv1d {
|
namespace NsCausalConv1d {
|
||||||
|
|
||||||
|
using namespace AscendC;
|
||||||
using namespace NsCausalConv1dCommon;
|
using namespace NsCausalConv1dCommon;
|
||||||
|
|
||||||
#ifndef CAUSAL_CONV1D_TILING_DATA_H_
|
|
||||||
#define CAUSAL_CONV1D_TILING_DATA_H_
|
|
||||||
|
|
||||||
struct CausalConv1dTilingData {
|
|
||||||
int64_t dim;
|
|
||||||
int64_t cuSeqlen;
|
|
||||||
int64_t seqLen;
|
|
||||||
int64_t inputMode;
|
|
||||||
|
|
||||||
int64_t width;
|
|
||||||
|
|
||||||
int64_t stateLen;
|
|
||||||
int64_t numCacheLines;
|
|
||||||
|
|
||||||
int64_t batch;
|
|
||||||
|
|
||||||
// attrs
|
|
||||||
int64_t activationMode; // 0: none, 1: silu/swish
|
|
||||||
int64_t padSlotId; // default -1
|
|
||||||
|
|
||||||
// optional inputs
|
|
||||||
int64_t hasBias; // 0/1
|
|
||||||
|
|
||||||
// Channel-wise tiling
|
|
||||||
int64_t dimTileSize;
|
|
||||||
int64_t blocksPerSeq;
|
|
||||||
};
|
|
||||||
#endif // CAUSAL_CONV1D_TILING_DATA_H_
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class CausalConv1d
|
class CausalConv1d
|
||||||
{
|
{
|
||||||
@@ -103,18 +36,19 @@ public:
|
|||||||
__aicore__ inline CausalConv1d() = default;
|
__aicore__ inline CausalConv1d() = default;
|
||||||
|
|
||||||
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, GM_ADDR queryStartLoc,
|
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates, GM_ADDR queryStartLoc,
|
||||||
GM_ADDR cacheIndices, GM_ADDR hasInitialState, GM_ADDR y
|
GM_ADDR cacheIndices, GM_ADDR initialStateMode, GM_ADDR numAcceptedTokens, GM_ADDR y,
|
||||||
,
|
|
||||||
const CausalConv1dTilingData* tilingData);
|
const CausalConv1dTilingData* tilingData);
|
||||||
__aicore__ inline void Process();
|
__aicore__ inline void Process();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
__aicore__ inline void LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg);
|
__aicore__ inline void LoadWeightAndBias(int32_t c0, int32_t dimTileSize);
|
||||||
__aicore__ inline void InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len,
|
__aicore__ inline void InitRing(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset, int32_t start, int32_t len,
|
||||||
int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg);
|
int32_t c0, int32_t dimTileSize, int32_t dim);
|
||||||
__aicore__ inline void RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg);
|
__aicore__ inline void RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim);
|
||||||
__aicore__ inline void WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
|
__aicore__ inline void WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0, int32_t dimTileSize, int32_t dim);
|
||||||
int32_t dimTileSize, int32_t dim, bool dbg);
|
__aicore__ inline void WriteBackStateSpec(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset,
|
||||||
|
int32_t start, int32_t len, int32_t c0, int32_t dimTileSize,
|
||||||
|
int32_t dim);
|
||||||
__aicore__ inline void AllocEvents();
|
__aicore__ inline void AllocEvents();
|
||||||
__aicore__ inline void ReleaseEvents();
|
__aicore__ inline void ReleaseEvents();
|
||||||
|
|
||||||
@@ -124,34 +58,43 @@ private:
|
|||||||
TBuf<QuePosition::VECOUT> outBuf;
|
TBuf<QuePosition::VECOUT> outBuf;
|
||||||
TBuf<QuePosition::VECCALC> calcBuf;
|
TBuf<QuePosition::VECCALC> calcBuf;
|
||||||
|
|
||||||
TEventID tempVToMte2Event_;
|
TEventID weightBiasMte2ToVEvent_;
|
||||||
TEventID tempMte2ToVEvent_;
|
TEventID stateMte2ToVEvent_;
|
||||||
TEventID inputMte2ToVEvent_;
|
TEventID inputMte2ToVEvent_[RING_SLOTS];
|
||||||
|
TEventID inputVToMte2Event_;
|
||||||
TEventID outMte3ToVEvent_[2];
|
TEventID outMte3ToVEvent_[2];
|
||||||
TEventID outVToMte3Event_[2];
|
TEventID outVToMte3Event_[2];
|
||||||
|
TEventID stateWritebackMte3ToVEvent_;
|
||||||
|
TEventID stateWritebackMte3ToMte2Event_;
|
||||||
|
TEventID specWritebackMte2ToMte3Event_[2];
|
||||||
|
TEventID specWritebackMte3ToMte2Event_[2];
|
||||||
|
|
||||||
GlobalTensor<T> xGm;
|
GlobalTensor<T> xGm;
|
||||||
GlobalTensor<T> weightGm;
|
GlobalTensor<T> weightGm;
|
||||||
GlobalTensor<T> biasGm;
|
GlobalTensor<T> biasGm;
|
||||||
GlobalTensor<T> convStatesGm;
|
GlobalTensor<T> convStatesGm;
|
||||||
GlobalTensor<int32_t> queryStartLocGm;
|
GlobalTensor<int64_t> queryStartLocGm;
|
||||||
GlobalTensor<int32_t> cacheIndicesGm;
|
GlobalTensor<int64_t> cacheIndicesGm;
|
||||||
GlobalTensor<bool> hasInitialStateGm;
|
GlobalTensor<int64_t> initialStateModeGm;
|
||||||
|
GlobalTensor<int64_t> numAcceptedTokensGm;
|
||||||
GlobalTensor<T> yGm;
|
GlobalTensor<T> yGm;
|
||||||
|
|
||||||
const CausalConv1dTilingData* tilingData_ {nullptr};
|
const CausalConv1dTilingData* tilingData_ {nullptr};
|
||||||
|
|
||||||
|
bool weightCacheValid_ {false};
|
||||||
|
int32_t cachedC0_ {-1};
|
||||||
|
int32_t cachedDimTileSize_ {-1};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
__aicore__ inline void CausalConv1d<T>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR convStates,
|
||||||
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR hasInitialState,
|
GM_ADDR queryStartLoc, GM_ADDR cacheIndices, GM_ADDR initialStateMode,
|
||||||
GM_ADDR y
|
GM_ADDR numAcceptedTokens, GM_ADDR y, const CausalConv1dTilingData* tilingData)
|
||||||
, const CausalConv1dTilingData* tilingData)
|
|
||||||
{
|
{
|
||||||
// REGISTER_TILING_DEFAULT(CausalConv1dTilingData);
|
|
||||||
// auto tiling = (__gm__ CausalConv1dTilingData*)tilingGM;
|
|
||||||
// GET_TILING_DATA(tilingData, tilingGM);
|
|
||||||
tilingData_ = tilingData;
|
tilingData_ = tilingData;
|
||||||
|
weightCacheValid_ = false;
|
||||||
|
cachedC0_ = -1;
|
||||||
|
cachedDimTileSize_ = -1;
|
||||||
|
|
||||||
xGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x));
|
xGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(x));
|
||||||
weightGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(weight));
|
weightGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(weight));
|
||||||
@@ -159,9 +102,18 @@ __aicore__ inline void CausalConv1d<T>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR
|
|||||||
biasGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(bias));
|
biasGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(bias));
|
||||||
}
|
}
|
||||||
convStatesGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(convStates));
|
convStatesGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(convStates));
|
||||||
queryStartLocGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(queryStartLoc));
|
if (tilingData_->inputMode == 0) {
|
||||||
cacheIndicesGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(cacheIndices));
|
queryStartLocGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(queryStartLoc));
|
||||||
hasInitialStateGm.SetGlobalBuffer(reinterpret_cast<__gm__ bool*>(hasInitialState));
|
}
|
||||||
|
if (tilingData_->hasCacheIndices != 0) {
|
||||||
|
cacheIndicesGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(cacheIndices));
|
||||||
|
}
|
||||||
|
if (tilingData_->hasInitialStateMode != 0) {
|
||||||
|
initialStateModeGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(initialStateMode));
|
||||||
|
}
|
||||||
|
if (tilingData_->hasNumAcceptedTokens != 0) {
|
||||||
|
numAcceptedTokensGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t*>(numAcceptedTokens));
|
||||||
|
}
|
||||||
yGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y));
|
yGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(y));
|
||||||
|
|
||||||
pipe.InitBuffer(inBuf, RING_SLOTS * MAX_BLOCK_DIM * sizeof(T));
|
pipe.InitBuffer(inBuf, RING_SLOTS * MAX_BLOCK_DIM * sizeof(T));
|
||||||
@@ -169,114 +121,143 @@ __aicore__ inline void CausalConv1d<T>::Init(GM_ADDR x, GM_ADDR weight, GM_ADDR
|
|||||||
pipe.InitBuffer(calcBuf, (MAX_WIDTH + 3) * MAX_BLOCK_DIM * sizeof(float));
|
pipe.InitBuffer(calcBuf, (MAX_WIDTH + 3) * MAX_BLOCK_DIM * sizeof(float));
|
||||||
|
|
||||||
AllocEvents();
|
AllocEvents();
|
||||||
|
|
||||||
// CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] dim=%d, dimTileSize=%d, blocksPerSeq=%d, batch=%d\n",
|
|
||||||
// tilingData_->dim, tilingData_->dimTileSize, tilingData_->blocksPerSeq, tilingData_->batch);
|
|
||||||
// CCONV_PRINT_IF(GetBlockIdx() == 0U, "[Init] hasBias=%d, activationMode=%d, stateLen=%d, inputMode=%d\n",
|
|
||||||
// tilingData_->hasBias, tilingData_->activationMode, tilingData_->stateLen, tilingData_->inputMode);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::AllocEvents()
|
__aicore__ inline void CausalConv1d<T>::AllocEvents()
|
||||||
{
|
{
|
||||||
tempVToMte2Event_ = GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>();
|
weightBiasMte2ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
||||||
tempMte2ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
stateMte2ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
||||||
inputMte2ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
for (int32_t i = 0; i < RING_SLOTS; ++i) {
|
||||||
|
inputMte2ToVEvent_[i] = GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>();
|
||||||
|
}
|
||||||
|
inputVToMte2Event_ = GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>();
|
||||||
outMte3ToVEvent_[0] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
outMte3ToVEvent_[0] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
||||||
outMte3ToVEvent_[1] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
outMte3ToVEvent_[1] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
||||||
outVToMte3Event_[0] = GetTPipePtr()->AllocEventID<HardEvent::V_MTE3>();
|
outVToMte3Event_[0] = GetTPipePtr()->AllocEventID<HardEvent::V_MTE3>();
|
||||||
outVToMte3Event_[1] = GetTPipePtr()->AllocEventID<HardEvent::V_MTE3>();
|
outVToMte3Event_[1] = GetTPipePtr()->AllocEventID<HardEvent::V_MTE3>();
|
||||||
|
stateWritebackMte3ToVEvent_ = GetTPipePtr()->AllocEventID<HardEvent::MTE3_V>();
|
||||||
|
stateWritebackMte3ToMte2Event_ = GetTPipePtr()->AllocEventID<HardEvent::MTE3_MTE2>();
|
||||||
|
specWritebackMte2ToMte3Event_[0] = GetTPipePtr()->AllocEventID<HardEvent::MTE2_MTE3>();
|
||||||
|
specWritebackMte2ToMte3Event_[1] = GetTPipePtr()->AllocEventID<HardEvent::MTE2_MTE3>();
|
||||||
|
specWritebackMte3ToMte2Event_[0] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_MTE2>();
|
||||||
|
specWritebackMte3ToMte2Event_[1] = GetTPipePtr()->AllocEventID<HardEvent::MTE3_MTE2>();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::ReleaseEvents()
|
__aicore__ inline void CausalConv1d<T>::ReleaseEvents()
|
||||||
{
|
{
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(tempVToMte2Event_);
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(weightBiasMte2ToVEvent_);
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(tempMte2ToVEvent_);
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(stateMte2ToVEvent_);
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(inputMte2ToVEvent_);
|
for (int32_t i = 0; i < RING_SLOTS; ++i) {
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(inputMte2ToVEvent_[i]);
|
||||||
|
}
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(inputVToMte2Event_);
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[0]);
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[0]);
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[1]);
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(outMte3ToVEvent_[1]);
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[0]);
|
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[0]);
|
||||||
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[1]);
|
GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE3>(outVToMte3Event_[1]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_V>(stateWritebackMte3ToVEvent_);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_MTE2>(stateWritebackMte3ToMte2Event_);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[0]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[1]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[0]);
|
||||||
|
GetTPipePtr()->ReleaseEventID<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::LoadWeightAndBias(int32_t c0, int32_t dimTileSize, bool dbg)
|
__aicore__ inline void CausalConv1d<T>::LoadWeightAndBias(int32_t c0, int32_t dimTileSize)
|
||||||
{
|
{
|
||||||
const int32_t dim = tilingData_->dim;
|
const int32_t dim = tilingData_->dim;
|
||||||
const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC;
|
const int32_t width = static_cast<int32_t>(tilingData_->width);
|
||||||
(void)dbgSync;
|
const int32_t jStart = MAX_WIDTH - width;
|
||||||
LocalTensor<float> calc = calcBuf.Get<float>();
|
LocalTensor<float> calc = calcBuf.Get<float>();
|
||||||
LocalTensor<float> weightF = calc;
|
LocalTensor<float> weightF = calc;
|
||||||
LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
|
LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
|
||||||
LocalTensor<T> tempT = outBuf.Get<T>();
|
const bool hasBias = (tilingData_->hasBias != 0);
|
||||||
|
|
||||||
// CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] c0=%d, dimTileSize=%d\n", c0, dimTileSize);
|
for (int32_t j = 0; j < width; ++j) {
|
||||||
|
const int32_t jDst = jStart + j;
|
||||||
for (int32_t j = 0; j < MAX_WIDTH; ++j) {
|
|
||||||
const int64_t weightOffset = static_cast<int64_t>(j) * dim + c0;
|
const int64_t weightOffset = static_cast<int64_t>(j) * dim + c0;
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
DataCopy(tempT, weightGm[weightOffset], dimTileSize);
|
if constexpr (std::is_same<T, float>::value) {
|
||||||
PipeBarrier<PIPE_ALL>();
|
DataCopy(weightF[jDst * MAX_BLOCK_DIM], weightGm[weightOffset], dimTileSize);
|
||||||
Cast(weightF[j * MAX_BLOCK_DIM], tempT, RoundMode::CAST_NONE, dimTileSize);
|
} else {
|
||||||
PipeBarrier<PIPE_ALL>();
|
DataCopy(weightF.ReinterpretCast<T>()[jDst * MAX_BLOCK_DIM * 2 + MAX_BLOCK_DIM], weightGm[weightOffset], dimTileSize);
|
||||||
// if (dbg && CCONV_DBG_DUMP_WEIGHTS) {
|
}
|
||||||
// CCONV_PRINTF("[Dump][weightF] j=%d\n", j);
|
|
||||||
// CCONV_DUMP_TENSOR_IF(true, weightF[j * MAX_BLOCK_DIM], CCONV_DBG_DUMP_SIZE);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tilingData_->hasBias != 0) {
|
if (hasBias) {
|
||||||
PipeBarrier<PIPE_ALL>();
|
if constexpr (std::is_same<T, float>::value) {
|
||||||
DataCopy(tempT, biasGm[c0], dimTileSize);
|
DataCopy(biasF, biasGm[c0], dimTileSize);
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
Cast(biasF, tempT, RoundMode::CAST_NONE, dimTileSize);
|
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
// if (dbg && CCONV_DBG_DUMP_BIAS) {
|
|
||||||
// CCONV_PRINTF("[Dump][biasF]\n");
|
|
||||||
// CCONV_DUMP_TENSOR_IF(true, biasF, CCONV_DBG_DUMP_SIZE);
|
|
||||||
// }
|
|
||||||
} else {
|
} else {
|
||||||
Duplicate(biasF, 0.0f, dimTileSize);
|
DataCopy(biasF.ReinterpretCast<T>()[MAX_BLOCK_DIM], biasGm[c0], dimTileSize);
|
||||||
// CCONV_PRINT_IF(dbg, "[LoadWeightAndBias] bias=0 (no bias)\n");
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SetFlag<HardEvent::MTE2_V>(weightBiasMte2ToVEvent_);
|
||||||
|
WaitFlag<HardEvent::MTE2_V>(weightBiasMte2ToVEvent_);
|
||||||
|
|
||||||
|
if constexpr (!std::is_same<T, float>::value) {
|
||||||
|
for (int32_t j = 0; j < width; ++j) {
|
||||||
|
const int32_t jDst = jStart + j;
|
||||||
|
Cast(weightF[jDst * MAX_BLOCK_DIM], weightF.ReinterpretCast<T>()[jDst * MAX_BLOCK_DIM * 2 + MAX_BLOCK_DIM],
|
||||||
|
RoundMode::CAST_NONE, dimTileSize);
|
||||||
|
}
|
||||||
|
if (hasBias) {
|
||||||
|
Cast(biasF, biasF.ReinterpretCast<T>()[MAX_BLOCK_DIM], RoundMode::CAST_NONE, dimTileSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasBias) {
|
||||||
|
Duplicate(biasF, 0.0f, dimTileSize);
|
||||||
}
|
}
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::InitRing(int32_t cacheIdx, bool hasInit, int32_t start, int32_t len,
|
__aicore__ inline void CausalConv1d<T>::InitRing(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset,
|
||||||
int32_t c0, int32_t dimTileSize, int32_t dim, bool dbg)
|
int32_t start, int32_t len, int32_t c0, int32_t dimTileSize,
|
||||||
|
int32_t dim)
|
||||||
{
|
{
|
||||||
const int32_t stateLen = tilingData_->stateLen;
|
const int32_t stateLen = tilingData_->stateLen;
|
||||||
|
const int32_t width = static_cast<int32_t>(tilingData_->width);
|
||||||
|
const int32_t ringStart = MAX_WIDTH - width;
|
||||||
LocalTensor<T> ring = inBuf.Get<T>();
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
|
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
if (hasInit) {
|
if (hasInit) {
|
||||||
for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) {
|
for (int32_t i = 0; i < (width - 1); ++i) {
|
||||||
|
const int32_t pos = stateTokenOffset + i;
|
||||||
const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
||||||
static_cast<int64_t>(i) * dim + c0;
|
static_cast<int64_t>(pos) * dim + c0;
|
||||||
DataCopy(ring[i * MAX_BLOCK_DIM], convStatesGm[stateOffset], dimTileSize);
|
DataCopy(ring[(ringStart + i) * MAX_BLOCK_DIM], convStatesGm[stateOffset], dimTileSize);
|
||||||
}
|
}
|
||||||
|
SetFlag<HardEvent::MTE2_V>(stateMte2ToVEvent_);
|
||||||
|
WaitFlag<HardEvent::MTE2_V>(stateMte2ToVEvent_);
|
||||||
} else {
|
} else {
|
||||||
for (int32_t i = 0; i < (MAX_WIDTH - 1); ++i) {
|
for (int32_t i = 0; i < (width - 1); ++i) {
|
||||||
Duplicate(ring[i * MAX_BLOCK_DIM], static_cast<T>(0), dimTileSize);
|
Duplicate(ring[(ringStart + i) * MAX_BLOCK_DIM], static_cast<T>(0), dimTileSize);
|
||||||
}
|
}
|
||||||
|
PipeBarrier<PIPE_V>();
|
||||||
}
|
}
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
|
|
||||||
if (len > 0) {
|
if (len > 0) {
|
||||||
|
const int32_t slot0 = SlotCurr(0);
|
||||||
const int64_t xOffset = static_cast<int64_t>(start) * dim + c0;
|
const int64_t xOffset = static_cast<int64_t>(start) * dim + c0;
|
||||||
PipeBarrier<PIPE_ALL>();
|
DataCopy(ring[slot0 * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
|
||||||
DataCopy(ring[SlotCurr(0) * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
|
SetFlag<HardEvent::MTE2_V>(inputMte2ToVEvent_[slot0]);
|
||||||
PipeBarrier<PIPE_ALL>();
|
}
|
||||||
|
|
||||||
|
if (len > 1) {
|
||||||
|
SetFlag<HardEvent::V_MTE2>(inputVToMte2Event_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize,
|
__aicore__ inline void CausalConv1d<T>::RunSeq(int32_t start, int32_t len, int32_t c0, int32_t dimTileSize,
|
||||||
int32_t dim, bool dbg)
|
int32_t dim)
|
||||||
{
|
{
|
||||||
|
const int32_t width = static_cast<int32_t>(tilingData_->width);
|
||||||
|
const int32_t jStart = MAX_WIDTH - width;
|
||||||
LocalTensor<float> calc = calcBuf.Get<float>();
|
LocalTensor<float> calc = calcBuf.Get<float>();
|
||||||
LocalTensor<float> weightF = calc;
|
LocalTensor<float> weightF = calc;
|
||||||
LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
|
LocalTensor<float> biasF = weightF[MAX_WIDTH * MAX_BLOCK_DIM];
|
||||||
@@ -284,78 +265,77 @@ __aicore__ inline void CausalConv1d<T>::RunSeq(int32_t start, int32_t len, int32
|
|||||||
LocalTensor<float> tmpF = accF[MAX_BLOCK_DIM];
|
LocalTensor<float> tmpF = accF[MAX_BLOCK_DIM];
|
||||||
LocalTensor<T> ring = inBuf.Get<T>();
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
LocalTensor<T> outT = outBuf.Get<T>();
|
LocalTensor<T> outT = outBuf.Get<T>();
|
||||||
const bool dbgSync = dbg && CCONV_DBG_PRINT_SYNC;
|
|
||||||
(void)dbgSync;
|
|
||||||
const bool hasActivation = (tilingData_->activationMode != 0);
|
const bool hasActivation = (tilingData_->activationMode != 0);
|
||||||
const int32_t dbgMaxTokens = CCONV_DBG_MAX_TOKENS;
|
|
||||||
const int32_t dbgVerboseTokens = CCONV_DBG_VERBOSE_TOKENS;
|
|
||||||
|
|
||||||
for (int32_t t = 0; t < len; ++t) {
|
for (int32_t t = 0; t < len; ++t) {
|
||||||
const bool dbgTok = dbg && (t < dbgMaxTokens);
|
|
||||||
const bool dbgVerbose = dbg && CCONV_DBG_DUMP_RUNSEQ && (t < dbgVerboseTokens);
|
|
||||||
const bool dbgStep = dbgVerbose && (t == 0);
|
|
||||||
const int32_t slotCurr = SlotCurr(t);
|
const int32_t slotCurr = SlotCurr(t);
|
||||||
const int32_t slotH1 = SlotHist(t, 1);
|
|
||||||
const int32_t slotH2 = SlotHist(t, 2);
|
WaitFlag<HardEvent::MTE2_V>(inputMte2ToVEvent_[slotCurr]);
|
||||||
const int32_t slotH3 = SlotHist(t, 3);
|
|
||||||
const int32_t slotPref = (t + 1 < len) ? SlotPrefetch(t) : -1;
|
|
||||||
const int32_t outSlot = t & 1;
|
|
||||||
|
|
||||||
if (t + 1 < len) {
|
if (t + 1 < len) {
|
||||||
const int64_t xOffset = static_cast<int64_t>(start + t + 1) * dim + c0;
|
const int32_t slotNext = SlotPrefetch(t);
|
||||||
PipeBarrier<PIPE_ALL>();
|
const int64_t xOffsetNext = static_cast<int64_t>(start + t + 1) * dim + c0;
|
||||||
DataCopy(ring[slotPref * MAX_BLOCK_DIM], xGm[xOffset], dimTileSize);
|
WaitFlag<HardEvent::V_MTE2>(inputVToMte2Event_);
|
||||||
PipeBarrier<PIPE_ALL>();
|
DataCopy(ring[slotNext * MAX_BLOCK_DIM], xGm[xOffsetNext], dimTileSize);
|
||||||
|
SetFlag<HardEvent::MTE2_V>(inputMte2ToVEvent_[slotNext]);
|
||||||
}
|
}
|
||||||
|
|
||||||
DataCopy(accF, biasF, dimTileSize);
|
DataCopy(accF, biasF, dimTileSize);
|
||||||
|
PipeBarrier<PIPE_V>();
|
||||||
|
|
||||||
|
for (int32_t j = jStart; j < MAX_WIDTH; ++j) {
|
||||||
for (int32_t j = 0; j < MAX_WIDTH; ++j) {
|
|
||||||
const int32_t tap = (MAX_WIDTH - 1) - j;
|
const int32_t tap = (MAX_WIDTH - 1) - j;
|
||||||
const int32_t slot = (tap == 0) ? slotCurr : SlotHist(t, tap);
|
const int32_t slot = (tap == 0) ? slotCurr : SlotHist(t, tap);
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
Cast(tmpF, ring[slot * MAX_BLOCK_DIM], RoundMode::CAST_NONE, dimTileSize);
|
Cast(tmpF, ring[slot * MAX_BLOCK_DIM], RoundMode::CAST_NONE, dimTileSize);
|
||||||
PipeBarrier<PIPE_ALL>();
|
// PipeBarrier<PIPE_V>();
|
||||||
|
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
MulAddDst(accF, tmpF, weightF[j * MAX_BLOCK_DIM], dimTileSize);
|
MulAddDst(accF, tmpF, weightF[j * MAX_BLOCK_DIM], dimTileSize);
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasActivation) {
|
if (hasActivation) {
|
||||||
Silu(tmpF, accF, dimTileSize);
|
Silu(tmpF, accF, dimTileSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
PipeBarrier<PIPE_ALL>();
|
const int32_t outSlot = t & 1;
|
||||||
|
LocalTensor<T> outSlotT = outT[outSlot * MAX_BLOCK_DIM];
|
||||||
|
if (t >= 2) {
|
||||||
|
WaitFlag<HardEvent::MTE3_V>(outMte3ToVEvent_[outSlot]);
|
||||||
|
}
|
||||||
if constexpr (IsSameType<T, float>::value) {
|
if constexpr (IsSameType<T, float>::value) {
|
||||||
if (hasActivation) {
|
if (hasActivation) {
|
||||||
DataCopy(outT[outSlot * MAX_BLOCK_DIM], tmpF, dimTileSize);
|
DataCopy(outSlotT, tmpF, dimTileSize);
|
||||||
} else {
|
} else {
|
||||||
DataCopy(outT[outSlot * MAX_BLOCK_DIM], accF, dimTileSize);
|
DataCopy(outSlotT, accF, dimTileSize);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (hasActivation) {
|
if (hasActivation) {
|
||||||
Cast(outT[outSlot * MAX_BLOCK_DIM], tmpF, RoundMode::CAST_RINT, dimTileSize);
|
Cast(outSlotT, tmpF, RoundMode::CAST_RINT, dimTileSize);
|
||||||
} else {
|
} else {
|
||||||
Cast(outT[outSlot * MAX_BLOCK_DIM], accF, RoundMode::CAST_RINT, dimTileSize);
|
Cast(outSlotT, accF, RoundMode::CAST_RINT, dimTileSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
|
SetFlag<HardEvent::V_MTE3>(outVToMte3Event_[outSlot]);
|
||||||
|
|
||||||
const int64_t outOffset = static_cast<int64_t>(start + t) * dim + c0;
|
const int64_t outOffset = static_cast<int64_t>(start + t) * dim + c0;
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
DataCopy(yGm[outOffset], outT[outSlot * MAX_BLOCK_DIM], dimTileSize);
|
WaitFlag<HardEvent::V_MTE3>(outVToMte3Event_[outSlot]);
|
||||||
PipeBarrier<PIPE_ALL>();
|
DataCopy(yGm[outOffset], outSlotT, dimTileSize);
|
||||||
|
if (t + 2 < len) {
|
||||||
|
SetFlag<HardEvent::MTE3_V>(outMte3ToVEvent_[outSlot]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (t + 2 < len) {
|
||||||
|
SetFlag<HardEvent::V_MTE2>(inputVToMte2Event_);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
|
__aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t len, int32_t c0,
|
||||||
int32_t dimTileSize, int32_t dim, bool dbg)
|
int32_t dimTileSize, int32_t dim)
|
||||||
{
|
{
|
||||||
const int32_t stateLen = tilingData_->stateLen;
|
const int32_t stateLen = tilingData_->stateLen;
|
||||||
|
const int32_t width = static_cast<int32_t>(tilingData_->width);
|
||||||
if (len <= 0) {
|
if (len <= 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -363,14 +343,95 @@ __aicore__ inline void CausalConv1d<T>::WriteBackState(int32_t cacheIdx, int32_t
|
|||||||
const int32_t lastT = len - 1;
|
const int32_t lastT = len - 1;
|
||||||
LocalTensor<T> ring = inBuf.Get<T>();
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
|
|
||||||
for (int32_t pos = 0; pos < (MAX_WIDTH - 1); ++pos) {
|
for (int32_t pos = 0; pos < (width - 1); ++pos) {
|
||||||
const int32_t tap = (MAX_WIDTH - 2) - pos;
|
const int32_t tap = (width - 2) - pos;
|
||||||
const int32_t slot = (tap == 0) ? SlotCurr(lastT) : SlotHist(lastT, tap);
|
const int32_t slot = (tap == 0) ? SlotCurr(lastT) : SlotHist(lastT, tap);
|
||||||
const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
const int64_t stateOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
||||||
static_cast<int64_t>(pos) * dim + c0;
|
static_cast<int64_t>(pos) * dim + c0;
|
||||||
PipeBarrier<PIPE_ALL>();
|
|
||||||
DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize);
|
DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize);
|
||||||
PipeBarrier<PIPE_ALL>();
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void CausalConv1d<T>::WriteBackStateSpec(int32_t cacheIdx, bool hasInit, int32_t stateTokenOffset,
|
||||||
|
int32_t start, int32_t len, int32_t c0,
|
||||||
|
int32_t dimTileSize, int32_t dim)
|
||||||
|
{
|
||||||
|
const int32_t width = static_cast<int32_t>(tilingData_->width);
|
||||||
|
const int32_t stateLen = tilingData_->stateLen;
|
||||||
|
if (len <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (width != 4) {
|
||||||
|
WriteBackState(cacheIdx, len, c0, dimTileSize, dim);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int32_t keep = MAX_WIDTH - 2;
|
||||||
|
const int32_t reqStateLen = keep + len;
|
||||||
|
if (reqStateLen > stateLen) {
|
||||||
|
WriteBackState(cacheIdx, len, c0, dimTileSize, dim);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LocalTensor<T> ring = inBuf.Get<T>();
|
||||||
|
LocalTensor<T> buf0 = ring[0 * MAX_BLOCK_DIM];
|
||||||
|
LocalTensor<T> buf1 = ring[1 * MAX_BLOCK_DIM];
|
||||||
|
|
||||||
|
if (hasInit) {
|
||||||
|
const int32_t srcPos0 = stateTokenOffset + 1;
|
||||||
|
const int32_t srcPos1 = stateTokenOffset + 2;
|
||||||
|
const int64_t srcOffset0 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(srcPos0) * dim + c0;
|
||||||
|
const int64_t srcOffset1 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(srcPos1) * dim + c0;
|
||||||
|
DataCopy(buf0, convStatesGm[srcOffset0], dimTileSize);
|
||||||
|
DataCopy(buf1, convStatesGm[srcOffset1], dimTileSize);
|
||||||
|
PipeBarrier<PIPE_MTE2>();
|
||||||
|
const int64_t dstOffset0 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(0) * dim + c0;
|
||||||
|
const int64_t dstOffset1 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(1) * dim + c0;
|
||||||
|
DataCopy(convStatesGm[dstOffset0], buf0, dimTileSize);
|
||||||
|
DataCopy(convStatesGm[dstOffset1], buf1, dimTileSize);
|
||||||
|
PipeBarrier<PIPE_MTE3>();
|
||||||
|
} else {
|
||||||
|
Duplicate(buf0, static_cast<T>(0), dimTileSize);
|
||||||
|
PipeBarrier<PIPE_V>();
|
||||||
|
const int64_t dstOffset0 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(0) * dim + c0;
|
||||||
|
const int64_t dstOffset1 = static_cast<int64_t>(cacheIdx) * stateLen * dim + static_cast<int64_t>(1) * dim + c0;
|
||||||
|
DataCopy(convStatesGm[dstOffset0], buf0, dimTileSize);
|
||||||
|
DataCopy(convStatesGm[dstOffset1], buf0, dimTileSize);
|
||||||
|
PipeBarrier<PIPE_MTE3>();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t xOffset0 = static_cast<int64_t>(start) * dim + c0;
|
||||||
|
DataCopy(buf0, xGm[xOffset0], dimTileSize);
|
||||||
|
SetFlag<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[0]);
|
||||||
|
|
||||||
|
for (int32_t t = 0; t < len; ++t) {
|
||||||
|
const int32_t curr = t & 1;
|
||||||
|
const int32_t next = curr ^ 1;
|
||||||
|
LocalTensor<T> currBuf = (curr == 0) ? buf0 : buf1;
|
||||||
|
LocalTensor<T> nextBuf = (next == 0) ? buf0 : buf1;
|
||||||
|
|
||||||
|
WaitFlag<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[curr]);
|
||||||
|
|
||||||
|
if (t + 1 < len) {
|
||||||
|
const int64_t xOffsetNext = static_cast<int64_t>(start + t + 1) * dim + c0;
|
||||||
|
if (t > 0) {
|
||||||
|
WaitFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[next]);
|
||||||
|
}
|
||||||
|
DataCopy(nextBuf, xGm[xOffsetNext], dimTileSize);
|
||||||
|
SetFlag<HardEvent::MTE2_MTE3>(specWritebackMte2ToMte3Event_[next]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t dstOffset = static_cast<int64_t>(cacheIdx) * stateLen * dim +
|
||||||
|
static_cast<int64_t>(keep + t) * dim + c0;
|
||||||
|
DataCopy(convStatesGm[dstOffset], currBuf, dimTileSize);
|
||||||
|
SetFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[curr]);
|
||||||
|
}
|
||||||
|
|
||||||
|
WaitFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[0]);
|
||||||
|
if (len > 1) {
|
||||||
|
WaitFlag<HardEvent::MTE3_MTE2>(specWritebackMte3ToMte2Event_[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,11 +444,14 @@ __aicore__ inline void CausalConv1d<T>::Process()
|
|||||||
const int32_t seqLen = tilingData_->seqLen;
|
const int32_t seqLen = tilingData_->seqLen;
|
||||||
const int32_t dimTileSize = static_cast<int32_t>(tilingData_->dimTileSize);
|
const int32_t dimTileSize = static_cast<int32_t>(tilingData_->dimTileSize);
|
||||||
const int32_t blocksPerSeq = static_cast<int32_t>(tilingData_->blocksPerSeq);
|
const int32_t blocksPerSeq = static_cast<int32_t>(tilingData_->blocksPerSeq);
|
||||||
|
const int32_t width = static_cast<int32_t>(tilingData_->width);
|
||||||
|
const bool isSpecDecodingGlobal =
|
||||||
|
(tilingData_->runMode == 1) && (tilingData_->hasNumAcceptedTokens != 0) && (width == 4);
|
||||||
|
|
||||||
const uint32_t blockIdx = GetBlockIdx();
|
const uint32_t blockIdx = GetBlockIdx();
|
||||||
const uint32_t blockNum = GetBlockNum();
|
const uint32_t blockNum = GetBlockNum();
|
||||||
|
|
||||||
if (dimTileSize <= 0 || blocksPerSeq <= 0 || dimTileSize > MAX_BLOCK_DIM || blocksPerSeq * dimTileSize != dim) {
|
if (dimTileSize <= 0 || blocksPerSeq <= 0 || dimTileSize > MAX_BLOCK_DIM || width < 2 || width > MAX_WIDTH) {
|
||||||
ReleaseEvents();
|
ReleaseEvents();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -397,9 +461,10 @@ __aicore__ inline void CausalConv1d<T>::Process()
|
|||||||
const int32_t seq = static_cast<int32_t>(task / blocksPerSeq);
|
const int32_t seq = static_cast<int32_t>(task / blocksPerSeq);
|
||||||
const int32_t dimBlockId = static_cast<int32_t>(task % blocksPerSeq);
|
const int32_t dimBlockId = static_cast<int32_t>(task % blocksPerSeq);
|
||||||
const int32_t c0 = dimBlockId * dimTileSize;
|
const int32_t c0 = dimBlockId * dimTileSize;
|
||||||
const bool dbg = (seq == CCONV_DBG_SEQ) && (c0 == CCONV_DBG_C0);
|
if (c0 >= dim) {
|
||||||
|
continue;
|
||||||
LoadWeightAndBias(c0, dimTileSize, dbg);
|
}
|
||||||
|
const int32_t dimTileSizeActual = (c0 + dimTileSize <= dim) ? dimTileSize : (dim - c0);
|
||||||
|
|
||||||
int32_t start = 0;
|
int32_t start = 0;
|
||||||
int32_t len = 0;
|
int32_t len = 0;
|
||||||
@@ -408,6 +473,9 @@ __aicore__ inline void CausalConv1d<T>::Process()
|
|||||||
const int32_t endVal = queryStartLocGm.GetValue(seq + 1);
|
const int32_t endVal = queryStartLocGm.GetValue(seq + 1);
|
||||||
start = startVal;
|
start = startVal;
|
||||||
len = endVal - startVal;
|
len = endVal - startVal;
|
||||||
|
} else if (inputMode == 2) {
|
||||||
|
start = seq;
|
||||||
|
len = 1;
|
||||||
} else {
|
} else {
|
||||||
start = seq * seqLen;
|
start = seq * seqLen;
|
||||||
len = seqLen;
|
len = seqLen;
|
||||||
@@ -417,16 +485,55 @@ __aicore__ inline void CausalConv1d<T>::Process()
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int32_t cacheIdx = cacheIndicesGm.GetValue(seq);
|
int32_t cacheIdx = seq;
|
||||||
if (cacheIdx == tilingData_->padSlotId) {
|
if (tilingData_->hasCacheIndices != 0) {
|
||||||
|
const int64_t cacheIdx64 = cacheIndicesGm.GetValue(seq);
|
||||||
|
if (cacheIdx64 == tilingData_->padSlotId) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
cacheIdx = static_cast<int32_t>(cacheIdx64);
|
||||||
|
}
|
||||||
|
|
||||||
const bool hasInit = hasInitialStateGm.GetValue(seq);
|
const bool hasInit =
|
||||||
|
(tilingData_->hasInitialStateMode != 0) ? (initialStateModeGm.GetValue(seq) != 0) : false;
|
||||||
|
int32_t stateTokenOffset = 0;
|
||||||
|
if (isSpecDecodingGlobal) {
|
||||||
|
int32_t accepted = static_cast<int32_t>(numAcceptedTokensGm.GetValue(seq));
|
||||||
|
stateTokenOffset = accepted - 1;
|
||||||
|
const int32_t maxOffset = static_cast<int32_t>(tilingData_->stateLen - (width - 1));
|
||||||
|
if (stateTokenOffset < 0) {
|
||||||
|
stateTokenOffset = 0;
|
||||||
|
} else if (stateTokenOffset > maxOffset) {
|
||||||
|
stateTokenOffset = maxOffset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
InitRing(cacheIdx, hasInit, start, len, c0, dimTileSize, dim, dbg);
|
const bool weightCacheHit =
|
||||||
RunSeq(start, len, c0, dimTileSize, dim, dbg);
|
weightCacheValid_ && (cachedC0_ == c0) && (cachedDimTileSize_ == dimTileSizeActual);
|
||||||
WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg);
|
if (!weightCacheHit) {
|
||||||
|
LoadWeightAndBias(c0, dimTileSizeActual);
|
||||||
|
weightCacheValid_ = true;
|
||||||
|
cachedC0_ = c0;
|
||||||
|
cachedDimTileSize_ = dimTileSizeActual;
|
||||||
|
}
|
||||||
|
|
||||||
|
InitRing(cacheIdx, hasInit, stateTokenOffset, start, len, c0, dimTileSizeActual, dim);
|
||||||
|
RunSeq(start, len, c0, dimTileSizeActual, dim);
|
||||||
|
|
||||||
|
SetFlag<HardEvent::MTE3_V>(stateWritebackMte3ToVEvent_);
|
||||||
|
WaitFlag<HardEvent::MTE3_V>(stateWritebackMte3ToVEvent_);
|
||||||
|
SetFlag<HardEvent::MTE3_MTE2>(stateWritebackMte3ToMte2Event_);
|
||||||
|
WaitFlag<HardEvent::MTE3_MTE2>(stateWritebackMte3ToMte2Event_);
|
||||||
|
|
||||||
|
if (isSpecDecodingGlobal) {
|
||||||
|
WriteBackStateSpec(cacheIdx, hasInit, stateTokenOffset, start, len, c0, dimTileSizeActual, dim);
|
||||||
|
} else {
|
||||||
|
WriteBackState(cacheIdx, len, c0, dimTileSizeActual, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
PipeBarrier<PIPE_V>();
|
||||||
|
PipeBarrier<PIPE_MTE2>();
|
||||||
|
PipeBarrier<PIPE_MTE3>();
|
||||||
}
|
}
|
||||||
|
|
||||||
ReleaseEvents();
|
ReleaseEvents();
|
||||||
|
|||||||
49
csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_data.h
Normal file
49
csrc/causal_conv1d/op_kernel/causal_conv1d_tiling_data.h
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
/**
|
||||||
|
* This program is free software, you can redistribute it and/or modify it.
|
||||||
|
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||||
|
* This file is a part of the CANN Open Software.
|
||||||
|
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
|
||||||
|
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||||
|
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||||
|
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \file causal_conv1d_tiling_data.h
|
||||||
|
* \brief tiling data struct
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef CAUSAL_CONV1D_TILING_DATA_H_
|
||||||
|
#define CAUSAL_CONV1D_TILING_DATA_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
struct CausalConv1dTilingData {
|
||||||
|
int64_t dim;
|
||||||
|
int64_t cuSeqlen;
|
||||||
|
int64_t seqLen;
|
||||||
|
int64_t inputMode;
|
||||||
|
int64_t runMode;
|
||||||
|
|
||||||
|
int64_t width;
|
||||||
|
|
||||||
|
int64_t stateLen;
|
||||||
|
int64_t numCacheLines;
|
||||||
|
|
||||||
|
int64_t batch;
|
||||||
|
|
||||||
|
int64_t activationMode;
|
||||||
|
int64_t padSlotId;
|
||||||
|
|
||||||
|
int64_t hasBias;
|
||||||
|
|
||||||
|
int64_t dimTileSize;
|
||||||
|
int64_t blocksPerSeq;
|
||||||
|
|
||||||
|
int64_t hasNumAcceptedTokens;
|
||||||
|
|
||||||
|
int64_t hasCacheIndices;
|
||||||
|
int64_t hasInitialStateMode;
|
||||||
|
};
|
||||||
|
#endif // CAUSAL_CONV1D_TILING_DATA_H_
|
||||||
@@ -633,38 +633,32 @@ npu_copy_and_expand_eagle_inputs(
|
|||||||
out_new_token_indices, out_hidden_state_mapping};
|
out_new_token_indices, out_hidden_state_mapping};
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor causal_conv1d_fn(
|
at::Tensor npu_causal_conv1d_custom(
|
||||||
const at::Tensor& mixed_qkv_non_spec_T,
|
const at::Tensor& x,
|
||||||
const at::Tensor& conv_weights,
|
const at::Tensor& weight,
|
||||||
const c10::optional<at::Tensor>& bias_opt,
|
|
||||||
c10::string_view activation,
|
|
||||||
const at::Tensor& conv_state,
|
const at::Tensor& conv_state,
|
||||||
const at::Tensor& has_initial_state,
|
const c10::optional<at::Tensor>& bias_opt,
|
||||||
const at::Tensor& non_spec_state_indices_tensor,
|
at::IntArrayRef query_start_loc_opt,
|
||||||
const at::Tensor& non_spec_query_start_loc,
|
at::IntArrayRef cache_indices_opt,
|
||||||
int64_t pad_slot_id)
|
at::IntArrayRef initial_state_mode_opt,
|
||||||
|
at::IntArrayRef num_accepted_tokens_opt,
|
||||||
|
int64_t activation_mode,
|
||||||
|
int64_t pad_slot_id,
|
||||||
|
int64_t run_mode)
|
||||||
{
|
{
|
||||||
at::Tensor x=mixed_qkv_non_spec_T; //不需要转置
|
at::Tensor output = at::empty(x.sizes(), x.options());
|
||||||
at::Tensor weight=conv_weights;//不需要转置
|
|
||||||
c10::optional<at::Tensor> biasOptional =bias_opt;
|
|
||||||
at::Tensor convStates= conv_state;
|
|
||||||
at::Tensor queryStartLoc=non_spec_query_start_loc;
|
|
||||||
at::Tensor cacheIndices=non_spec_state_indices_tensor;
|
|
||||||
at::Tensor hasInitialState=has_initial_state;
|
|
||||||
int64_t activationMode=(activation.empty()?0:1);
|
|
||||||
int64_t padSlotId=pad_slot_id;
|
|
||||||
|
|
||||||
at::Tensor output = at::empty(mixed_qkv_non_spec_T.sizes(), mixed_qkv_non_spec_T.options());
|
|
||||||
EXEC_NPU_CMD(aclnnCausalConv1d,
|
EXEC_NPU_CMD(aclnnCausalConv1d,
|
||||||
x,
|
x,
|
||||||
weight,
|
weight,
|
||||||
biasOptional,
|
bias_opt,
|
||||||
convStates,
|
conv_state,
|
||||||
queryStartLoc,
|
query_start_loc_opt,
|
||||||
cacheIndices,
|
cache_indices_opt,
|
||||||
hasInitialState,
|
initial_state_mode_opt,
|
||||||
activationMode,
|
num_accepted_tokens_opt,
|
||||||
padSlotId,
|
activation_mode,
|
||||||
|
pad_slot_id,
|
||||||
|
run_mode,
|
||||||
output
|
output
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -895,18 +889,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
|||||||
"Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
|
"Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
|
||||||
);
|
);
|
||||||
ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
|
ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
|
||||||
// causal_conv1d_fn
|
|
||||||
ops.def(
|
ops.def(
|
||||||
"causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "
|
"npu_causal_conv1d_custom(Tensor x, "
|
||||||
" Tensor conv_weights, "
|
" Tensor weight, "
|
||||||
" Tensor? bias_opt, "
|
|
||||||
" str activation, "
|
|
||||||
" Tensor conv_state, "
|
" Tensor conv_state, "
|
||||||
" Tensor has_initial_state, "
|
" Tensor? bias_opt, "
|
||||||
" Tensor non_spec_state_indices_tensor, "
|
" int[] query_start_loc_opt, "
|
||||||
" Tensor non_spec_query_start_loc, "
|
" int[] cache_indices_opt, "
|
||||||
" int pad_slot_id) -> (Tensor output)");
|
" int[] initial_state_mode_opt, "
|
||||||
ops.impl("causal_conv1d_fn", torch::kPrivateUse1, &vllm_ascend::causal_conv1d_fn);
|
" int[] num_accepted_tokens_opt, "
|
||||||
|
" int activation_mode, "
|
||||||
|
" int pad_slot_id, "
|
||||||
|
" int run_mode"
|
||||||
|
") -> (Tensor output)");
|
||||||
|
ops.impl("npu_causal_conv1d_custom", torch::kPrivateUse1, &vllm_ascend::npu_causal_conv1d_custom);
|
||||||
ops.def(
|
ops.def(
|
||||||
"moe_grouped_matmul("
|
"moe_grouped_matmul("
|
||||||
"Tensor x,"
|
"Tensor x,"
|
||||||
|
|||||||
@@ -485,19 +485,21 @@ npu_copy_and_expand_eagle_inputs_meta(
|
|||||||
out_new_token_indices, out_hidden_state_mapping};
|
out_new_token_indices, out_hidden_state_mapping};
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor causal_conv1d_fn_meta(
|
at::Tensor npu_causal_conv1d_custom_meta(
|
||||||
const at::Tensor& mixed_qkv_non_spec_T,
|
const at::Tensor& x,
|
||||||
const at::Tensor& conv_weights,
|
const at::Tensor& weight,
|
||||||
const c10::optional<at::Tensor>& bias_opt,
|
|
||||||
c10::string_view activation,
|
|
||||||
const at::Tensor& conv_state,
|
const at::Tensor& conv_state,
|
||||||
const at::Tensor& has_initial_state,
|
const c10::optional<at::Tensor>& bias_opt,
|
||||||
const at::Tensor& non_spec_state_indices_tensor,
|
at::IntArrayRef query_start_loc_opt,
|
||||||
const at::Tensor& non_spec_query_start_loc,
|
at::IntArrayRef cache_indices_opt,
|
||||||
int64_t pad_slot_id)
|
at::IntArrayRef initial_state_mode_opt,
|
||||||
|
at::IntArrayRef num_accepted_tokens_opt,
|
||||||
|
int64_t activation_mode,
|
||||||
|
int64_t pad_slot_id,
|
||||||
|
int64_t run_mode)
|
||||||
{
|
{
|
||||||
|
|
||||||
at::Tensor output = at::empty_symint(mixed_qkv_non_spec_T.sym_sizes(), mixed_qkv_non_spec_T.options());
|
at::Tensor output = at::empty_symint(x.sym_sizes(), x.options());
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -611,7 +613,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
|||||||
// CopyAndExpandEagleInputs
|
// CopyAndExpandEagleInputs
|
||||||
ops.impl("npu_copy_and_expand_eagle_inputs", &vllm_ascend::meta::npu_copy_and_expand_eagle_inputs_meta);
|
ops.impl("npu_copy_and_expand_eagle_inputs", &vllm_ascend::meta::npu_copy_and_expand_eagle_inputs_meta);
|
||||||
// causal_conv1d_fn
|
// causal_conv1d_fn
|
||||||
ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
|
ops.impl("npu_causal_conv1d_custom", &vllm_ascend::meta::npu_causal_conv1d_custom_meta);
|
||||||
// moe_grouped_matmul
|
// moe_grouped_matmul
|
||||||
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
|
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
|
||||||
// Lightning indexer quant
|
// Lightning indexer quant
|
||||||
|
|||||||
@@ -157,6 +157,11 @@ def causal_conv1d_fn_pytorch(
|
|||||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||||
return out_ref_tensor
|
return out_ref_tensor
|
||||||
|
|
||||||
|
def to_int64_tuple(t):
|
||||||
|
t = t.to(torch.int64)
|
||||||
|
if t.dim() == 0:
|
||||||
|
return (t.item(),)
|
||||||
|
return tuple(t.tolist())
|
||||||
|
|
||||||
@pytest.mark.parametrize('has_initial_state', [False, True])
|
@pytest.mark.parametrize('has_initial_state', [False, True])
|
||||||
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
||||||
@@ -227,16 +232,19 @@ def test_ascend_causal_conv1d(dim, width, extra_state_len, seq_len, has_bias,
|
|||||||
x_origin=x.transpose(-1, -2)
|
x_origin=x.transpose(-1, -2)
|
||||||
weight_origin=weight.transpose(-1, -2)
|
weight_origin=weight.transpose(-1, -2)
|
||||||
conv_states_origin=conv_states.transpose(-1, -2)
|
conv_states_origin=conv_states.transpose(-1, -2)
|
||||||
out = torch.ops._C_ascend.causal_conv1d_fn(
|
activation_num = 1 if activation else 0
|
||||||
|
out = torch.ops._C_ascend.npu_causal_conv1d_custom(
|
||||||
x_origin,
|
x_origin,
|
||||||
weight_origin,
|
weight_origin,
|
||||||
bias,
|
|
||||||
activation=activation,
|
|
||||||
conv_state=conv_states_origin,
|
conv_state=conv_states_origin,
|
||||||
has_initial_state=has_initial_state_tensor,
|
bias_opt=bias,
|
||||||
non_spec_state_indices_tensor=cache_indices,
|
query_start_loc_opt=to_int64_tuple(query_start_loc),
|
||||||
non_spec_query_start_loc=query_start_loc,
|
cache_indices_opt=to_int64_tuple(cache_indices),
|
||||||
|
initial_state_mode_opt=to_int64_tuple(has_initial_state_tensor),
|
||||||
|
num_accepted_tokens_opt=[],
|
||||||
|
activation_mode=activation_num,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
run_mode=0
|
||||||
).transpose(-1, -2)
|
).transpose(-1, -2)
|
||||||
validate_cmp(out, out_ref, itype)
|
validate_cmp(out, out_ref, itype)
|
||||||
validate_cmp(conv_states, conv_states_ref, itype)
|
validate_cmp(conv_states, conv_states_ref, itype)
|
||||||
|
|||||||
@@ -33,6 +33,13 @@ from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
|||||||
from vllm_ascend.utils import enable_sp, vllm_version_is
|
from vllm_ascend.utils import enable_sp, vllm_version_is
|
||||||
|
|
||||||
|
|
||||||
|
def to_int64_tuple(t):
|
||||||
|
t = t.to(torch.int64)
|
||||||
|
if t.dim() == 0:
|
||||||
|
return (t.item(),)
|
||||||
|
return tuple(t.tolist())
|
||||||
|
|
||||||
|
|
||||||
class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
|
class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
|
||||||
def _forward_core(
|
def _forward_core(
|
||||||
self,
|
self,
|
||||||
@@ -110,16 +117,19 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
|
|||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
if mixed_qkv_non_spec is not None:
|
if mixed_qkv_non_spec is not None:
|
||||||
conv_weights_T = conv_weights.transpose(0, 1)
|
conv_weights_T = conv_weights.transpose(0, 1)
|
||||||
mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
|
activation_num = 1 if self.activation else 0
|
||||||
|
mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom(
|
||||||
mixed_qkv_non_spec,
|
mixed_qkv_non_spec,
|
||||||
conv_weights_T,
|
conv_weights_T,
|
||||||
self.conv1d.bias,
|
|
||||||
activation=self.activation,
|
|
||||||
conv_state=self_kv_cache[0],
|
conv_state=self_kv_cache[0],
|
||||||
has_initial_state=has_initial_state,
|
bias_opt=self.conv1d.bias,
|
||||||
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc),
|
||||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor),
|
||||||
|
initial_state_mode_opt=to_int64_tuple(has_initial_state),
|
||||||
|
num_accepted_tokens_opt=[],
|
||||||
|
activation_mode=activation_num,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
run_mode=0,
|
||||||
)
|
)
|
||||||
elif attn_metadata.num_decodes > 0:
|
elif attn_metadata.num_decodes > 0:
|
||||||
mixed_qkv_non_spec = causal_conv1d_update(
|
mixed_qkv_non_spec = causal_conv1d_update(
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
|||||||
from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
|
from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
|
||||||
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
|
||||||
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
|
||||||
|
from vllm_ascend.patch.worker.patch_qwen3_5 import to_int64_tuple
|
||||||
from vllm_ascend.utils import enable_sp, vllm_version_is
|
from vllm_ascend.utils import enable_sp, vllm_version_is
|
||||||
|
|
||||||
|
|
||||||
@@ -167,16 +168,19 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
|||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
if mixed_qkv_non_spec is not None:
|
if mixed_qkv_non_spec is not None:
|
||||||
conv_weights_T = conv_weights.transpose(0, 1)
|
conv_weights_T = conv_weights.transpose(0, 1)
|
||||||
mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
|
activation_num = 1 if self.activation else 0
|
||||||
|
mixed_qkv_non_spec = torch.ops._C_ascend.npu_causal_conv1d_custom(
|
||||||
mixed_qkv_non_spec,
|
mixed_qkv_non_spec,
|
||||||
conv_weights_T,
|
conv_weights_T,
|
||||||
self.conv1d.bias,
|
|
||||||
activation=self.activation,
|
|
||||||
conv_state=self_kv_cache[0],
|
conv_state=self_kv_cache[0],
|
||||||
has_initial_state=has_initial_state,
|
bias_opt=self.conv1d.bias,
|
||||||
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
query_start_loc_opt=to_int64_tuple(non_spec_query_start_loc),
|
||||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
cache_indices_opt=to_int64_tuple(non_spec_state_indices_tensor),
|
||||||
|
initial_state_mode_opt=to_int64_tuple(has_initial_state),
|
||||||
|
num_accepted_tokens_opt=[],
|
||||||
|
activation_mode=activation_num,
|
||||||
pad_slot_id=PAD_SLOT_ID,
|
pad_slot_id=PAD_SLOT_ID,
|
||||||
|
run_mode=0,
|
||||||
)
|
)
|
||||||
elif attn_metadata.num_decodes > 0:
|
elif attn_metadata.num_decodes > 0:
|
||||||
mixed_qkv_non_spec = causal_conv1d_update(
|
mixed_qkv_non_spec = causal_conv1d_update(
|
||||||
|
|||||||
Reference in New Issue
Block a user